package com.databricks.spark.sql.perf.mllib

import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}


class MLLibSuite extends FunSuite with BeforeAndAfterAll {

  private var sparkSession: SparkSession = _

  // Log levels saved in beforeAll so that afterAll can restore them once the suite finishes.
  var savedLevels: Map[String, Level] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sparkSession = SparkSession.builder.master("local[2]").appName("MLlib QA").getOrCreate()

    // Travis limits the size of the log file produced by a build. Even though this suite only
    // runs a small version of each ML benchmark, it still produces a ton of logs. Set the log
    // level to ERROR, just for this suite, to keep the build output under that limit.
    savedLevels = Seq("akka", "org", "com.databricks").map { name =>
      val logger = Logger.getLogger(name)
      val curLevel = logger.getLevel
      logger.setLevel(Level.ERROR)
      name -> curLevel
    }.toMap
  }

  override def afterAll(): Unit = {
    savedLevels.foreach { case (name, level) =>
      Logger.getLogger(name).setLevel(level)
    }
    try {
      if (sparkSession != null) {
        sparkSession.stop()
      }
      // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown
      System.clearProperty("spark.driver.port")
      sparkSession = null
    } finally {
      super.afterAll()
    }
  }

  // Runs the small version of all MLlib benchmarks and fails the test if any of them reports
  // a failure, printing the error and message for each failing benchmark.
  test("test MLlib benchmarks with mllib-small.yaml.") {
    val results = MLLib.run(yamlConfig = MLLib.smallConfig)
    val failures = results.na.drop(Seq("failure"))
    if (failures.count() > 0) {
      failures.select("name", "failure.*").collect().foreach {
        case Row(name: String, error: String, message: String) =>
          println(
            s"""There was a failure in the benchmark for $name:
               |  $error ${message.replace("\n", "\n  ")}
             """.stripMargin)
      }
      fail("Unable to run all benchmarks successfully, see console output for more info.")
    }
  }

  // Verifies that setup (beforeBenchmark) succeeds for every pipeline benchmark in the small config.
  test("test before benchmark methods for pipeline benchmarks.") {
    val benchmarks = MLLib.getBenchmarks(MLLib.getConf(yamlConfig = MLLib.smallConfig))
    benchmarks.foreach { b =>
      b.beforeBenchmark()
    }
  }
}