// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT-0 import org.apache.spark.SparkConf import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.types.StructType import scala.util.{Failure, Success, Try} object TestUtils { val param = DataGenParam() val batchParam = DataGenParam(targetDir="src/test/resources/results/raw") val dwhParam = DataGenParam(targetDir="src/test/resources/results/cleaned") val streamParam = DataGenParam(targetDir="src/test/resources/results/stream") val conf = new SparkConf() .setMaster("local[*]") .setAppName("data-generator") .set("spark.driver.allowMultipleContexts", "false") .set("spark.driver.memory","6g") .set("spark.sql.shuffle.partitions",param.parallelism.toString) .set("spark.ui.enabled", "true") .set("spark.ui.enabled", "true") .set("spark.sql.session.timeZone", "UTC") val spark = SparkSession.builder().config(conf).getOrCreate() def checkCsvSchema(spark: SparkSession, tablePath: String, tableSchema: StructType): Boolean = { val dfLoad = Try(spark.read.option("header","true").csv(tablePath)) checkSchema(dfLoad, tableSchema) } def checkParquetSchema(spark: SparkSession, tablePath: String, tableSchema: StructType): Boolean = { val dfLoad = Try(spark.read.parquet(tablePath)) checkSchema(dfLoad, tableSchema) } def checkJsonSchema(spark: SparkSession, tablePath: String, tableSchema: StructType): Boolean = { val dfLoad = Try(spark.read.json(tablePath)) checkSchema(dfLoad, tableSchema) } private def checkSchema(dfLoad: Try[Dataset[Row]], schema: StructType): Boolean = { dfLoad match { case Success(df) => df.printSchema; schema.printTreeString(); df.schema.diff(schema).isEmpty case Failure(f) => { println(f) false } } } }