tokuhirom's Blog

SparkSQL のクエリをユニットテストしたい

品質向上のために Spark クエリのユニットテストを実施したいという場合、JVM 言語で開発している場合には、Spark/hive をライブラリとしてロードできるから、容易に実装することができる。

dependencies {
    implementation 'org.apache.spark:spark-core_2.12:3.0.0'
    implementation 'org.apache.spark:spark-sql_2.12:3.0.0'
}

のように、関連するモジュールを依存に追加する。

以下のような、テストに利用するデータを json 形式などで用意する(spark は CSV, TSV などの形式も利用可能だから、好きなものを使えばよい)

{"name": "Nick",	"age":35,	"extra_fields": "{\"interests\":[\"car\", \"golf\"]}"}
{"name": "John",	"age":23}
{"name":"Cathy",	"age":44,	"extra_fields":"{\"interests\":[\"cooking\"]}"}

あとは、実際に spark session を作成し、local モードで spark を起動させれば良い。

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

// test without
class SimpleTest {
    fun run() {
        val spark: SparkSession = SparkSession
            .builder()
            .appName("Java Spark SQL basic example") // your application name
            .config("spark.master", "local")  // run on local machine, single thread.
            .config("spark.ui.enabled", false)
            .getOrCreate()

        val resourcePath = javaClass.classLoader.getResource("test-data/people.json")!!.toString()
        println("++++ read csv from: $resourcePath")

        val df = spark.read()
            .json(resourcePath)
        df.show()
        df.printSchema()

        println("++++ create table")
        df.createTempView("people")

        println("++++ select")
        val sqlDF: Dataset<Row> = spark.sql("SELECT * FROM people")
        sqlDF.show(false)

        println("++++ select 2nd")
        val sqlDF2: Dataset<Row> = spark.sql("SELECT name, get_json_object(extra_fields, '$.interests') interests FROM people")
        sqlDF2.show()

        println("++++ select 3rd")
        val sqlDF3: Dataset<Row> = spark.sql("SELECT avg(age) avg_age FROM people")
        sqlDF3.show()
    }
}

fun main() {
    SimpleTest().run()
}

hive を使う場合

実際に動かすクエリが select * from your_db_name.your_table のようにDB 名を指定していて、そのクエリ自体を変えずにテストしたいという場合には、hive サポートを有効にする必要がある。

hive を使う場合、spark-hive を依存に追加する。

dependencies {
    implementation 'org.apache.spark:spark-core_2.12:3.0.0'
    implementation 'org.apache.spark:spark-sql_2.12:3.0.0'
    implementation 'org.apache.spark:spark-hive_2.12:3.0.0'
}

あとは以下のように DB を作って入れるだけ。

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession

class TestClass {
    fun run() {
        val warehouseLocation = createTempDir()
        println("++++ warehouseLocation=$warehouseLocation")
        val spark: SparkSession = SparkSession
            .builder()
            .appName("Java Spark SQL basic example") // your application name
            .config("spark.master", "local")  // run on local machine, single thread.
            .config("spark.sql.warehouse.dir", warehouseLocation.toString())
            .config("spark.ui.enabled", false)
            .enableHiveSupport()
            .getOrCreate()

        val resourcePath = javaClass.classLoader.getResource("test-data/people.json")!!.toString()
        println("++++ read csv from: $resourcePath")

        val df = spark.read()
            .json(resourcePath)
        df.show()
        df.printSchema()

        println("++++ create table")
        spark.sql("create database if not exists foo")
        df.write().mode(SaveMode.Overwrite).saveAsTable("foo.people")
        spark.sql("show tables").show()
        spark.sql("show create table foo.people").show(false)

        // If the type of data is the important thing, you need to write your schema by yourself.
        //        spark.sql("""drop table if exists `foo`.`people`""")
//        spark.sql("""
//        CREATE TABLE `foo`.`people` (
//            `name` STRING,
//            `age` long,
//            `extra_fields` STRING)
//        USING parquet""".trimIndent())
//        df.write().insertInto("foo.people")


        println("++++ select")
        val sqlDF: Dataset<Row> = spark.sql("SELECT * FROM foo.people")
        sqlDF.show(false)

        println("++++ select 2nd")
        val sqlDF2: Dataset<Row> = spark.sql("SELECT name, get_json_object(extra_fields, '$.interests') interests FROM foo.people")
        sqlDF2.show()

        println("++++ select 3rd")
        val sqlDF3: Dataset<Row> = spark.sql("SELECT avg(age) avg_age FROM foo.people")
        sqlDF3.show()
    }
}

fun main() {
    TestClass().run()
}

クエリを変更しなくていいというメリットがある一方で、hive にアクセスするので依存も増えるし、実行もめちゃくちゃ遅くなります。

df.write().mode(SaveMode.Overwrite).saveAsTable("foo.people")

のようにすると、df 側の型をみていい感じにテーブル定義してくれて便利だが、明示的に create table したいときは以下のようにしたほうがいいかも。

        spark.sql("""drop table if exists `foo`.`people`""")
        spark.sql("""
        CREATE TABLE `foo`.`people` (
            `name` STRING,
            `age` long,
            `extra_fields` STRING)
        USING parquet""".trimIndent())
        df.write().insertInto("foo.people")

両者の比較

hive を利用しない場合、上記コードは 4.427 sec 程度で終わりますが、hive を利用する場合は 19.676 sec 程度かかるようになります。 プロダクションコードのテストをする場合はこの差はそこそこでかいかも。

sample code

https://github.com/tokuhirom/sparksql-unittest