/*
 * Copyright 2015 Databricks Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.spark.sql.perf

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.util.{Success, Try, Failure => SFailure}
import scala.util.control.NonFatal

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.SparkContext

import com.databricks.spark.sql.perf.cpu._

/**
 * A collection of queries that test a particular aspect of Spark SQL.
 *
 * @param sqlContext An existing SQLContext.
 */
abstract class Benchmark(
    @transient val sqlContext: SQLContext)
  extends Serializable {

  import Benchmark._

  def this() = this(SparkSession.builder.getOrCreate().sqlContext)

  val resultsLocation =
    sqlContext.getAllConfs.getOrElse(
      "spark.sql.perf.results",
      "/spark/sql/performance")

  protected def sparkContext = sqlContext.sparkContext

  protected implicit def toOption[A](a: A): Option[A] = Option(a)

  val buildInfo = Try(getClass.getClassLoader.loadClass("org.apache.spark.BuildInfo")).map { cls =>
    cls.getMethods
      .filter(_.getReturnType == classOf[String])
      .filterNot(_.getName == "toString")
      .map(m => m.getName -> m.invoke(cls).asInstanceOf[String])
      .toMap
  }.getOrElse(Map.empty)

  def currentConfiguration = BenchmarkConfiguration(
    sqlConf = sqlContext.getAllConfs,
    sparkConf = sparkContext.getConf.getAll.toMap,
    defaultParallelism = sparkContext.defaultParallelism,
    buildInfo = buildInfo)

  val codegen = Variation("codegen", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.codegen", "false")
    case "on" => sqlContext.setConf("spark.sql.codegen", "true")
  }

  val unsafe = Variation("unsafe", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.unsafe.enabled", "false")
    case "on" => sqlContext.setConf("spark.sql.unsafe.enabled", "true")
  }

  val tungsten = Variation("tungsten", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.tungsten.enabled", "false")
    case "on" => sqlContext.setConf("spark.sql.tungsten.enabled", "true")
  }

  /**
   * Starts an experiment run with a given set of executions to run.
   *
   * @param executionsToRun a list of executions to run.
   * @param includeBreakdown If true, breakdown results of an execution will be recorded.
   *                         Setting it to true may significantly increase the time needed to
   *                         run an execution.
   * @param iterations The number of iterations to run for each execution.
   * @param variations [[Variation]]s used in this run. The cross product of all variations will be
   *                   run for each execution * iteration.
   * @param tags Tags of this run.
   * @param timeout Wait at most `timeout` milliseconds for each query; 0 means wait forever.
   * @return an ExperimentStatus object that can be used to track the progress of this
   *         experiment run.
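   *
   * A minimal sketch of a typical call; the `tpcds` instance and its `tpcds2_4Queries`
   * query list are illustrative and not defined in this file:
   * {{{
   *   val experiment = tpcds.runExperiment(
   *     executionsToRun = tpcds.tpcds2_4Queries,
   *     iterations = 3,
   *     tags = Map("runtype" -> "nightly"))
   *   // Block until the run finishes (here, at most four hours).
   *   experiment.waitForFinish(60 * 60 * 4)
   * }}}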
   */
  def runExperiment(
      executionsToRun: Seq[Benchmarkable],
      includeBreakdown: Boolean = false,
      iterations: Int = 3,
      variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("true")) { _ => {} }),
      tags: Map[String, String] = Map.empty,
      timeout: Long = 0L,
      resultLocation: String = resultsLocation,
      forkThread: Boolean = true) = {
    new ExperimentStatus(
      executionsToRun,
      includeBreakdown,
      iterations,
      variations,
      tags,
      timeout,
      resultLocation,
      sqlContext,
      allTables,
      currentConfiguration,
      forkThread = forkThread)
  }

  import scala.reflect.runtime._
  import universe._

  @transient
  private val runtimeMirror = universe.runtimeMirror(getClass.getClassLoader)

  @transient
  val myType = runtimeMirror.classSymbol(getClass).toType

  def singleTables = myType.declarations
    .filter(m => m.isMethod)
    .map(_.asMethod)
    .filter(_.asMethod.returnType =:= typeOf[Table])
    .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Table])

  def groupedTables = myType.declarations
    .filter(m => m.isMethod)
    .map(_.asMethod)
    .filter(_.asMethod.returnType =:= typeOf[Seq[Table]])
    .flatMap(method =>
      runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Table]])

  @transient
  lazy val allTables: Seq[Table] = (singleTables ++ groupedTables).toSeq

  def singleQueries = myType.declarations
    .filter(m => m.isMethod)
    .map(_.asMethod)
    .filter(_.asMethod.returnType =:= typeOf[Benchmarkable])
    .map(method =>
      runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Benchmarkable])

  def groupedQueries = myType.declarations
    .filter(m => m.isMethod)
    .map(_.asMethod)
    .filter(_.asMethod.returnType =:= typeOf[Seq[Benchmarkable]])
    .flatMap(method =>
      runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Benchmarkable]])

  @transient
  lazy val allQueries = (singleQueries ++ groupedQueries).toSeq

  def html: String = {
    val singleQueries = myType.declarations
      .filter(m => m.isMethod)
      .map(_.asMethod)
      .filter(_.asMethod.returnType =:= typeOf[Query])
      .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Query])
      .mkString(",")
    val queries = myType.declarations
      .filter(m => m.isMethod)
      .map(_.asMethod)
      .filter(_.asMethod.returnType =:= typeOf[Seq[Query]])
      .map { method =>
        val queries =
          runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Query]]
        val queryList = queries.map(_.name).mkString(", ")
        s"""
           |<h3>${method.name}</h3>
           |<ul>$queryList</ul>
         """.stripMargin
      }.mkString("\n")

    s"""
       |<h1>Spark SQL Performance Benchmarking</h1>
       |<h2>Available Queries</h2>
       |$singleQueries
       |$queries
     """.stripMargin
  }

  /** Factory object for benchmark queries. */
  case object Query {
    def apply(
        name: String,
        sqlText: String,
        description: String,
        executionMode: ExecutionMode = ExecutionMode.ForeachResults): Query = {
      new Query(name, sqlContext.sql(sqlText), description, Some(sqlText), executionMode)
    }

    def apply(
        name: String,
        dataFrameBuilder: => DataFrame,
        description: String): Query = {
      new Query(name, dataFrameBuilder, description, None, ExecutionMode.CollectResults)
    }
  }

  object RDDCount {
    def apply(
        name: String,
        rdd: RDD[_]) = {
      new SparkPerfExecution(
        name,
        Map.empty,
        () => (),
        () => rdd.count(),
        rdd.toDebugString)
    }
  }

  /** A class for benchmarking Spark perf results. */
  class SparkPerfExecution(
      override val name: String,
      parameters: Map[String, String],
      prepare: () => Unit,
      run: () => Unit,
      description: String = "")
    extends Benchmarkable {

    override def toString: String =
      s"""
         |== $name ==
         |$description
       """.stripMargin

    protected override val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults

    protected override def beforeBenchmark(): Unit = { prepare() }

    protected override def doBenchmark(
        includeBreakdown: Boolean,
        description: String = "",
        messages: ArrayBuffer[String]): BenchmarkResult = {
      try {
        val timeMs = measureTimeMs(run())
        BenchmarkResult(
          name = name,
          mode = executionMode.toString,
          parameters = parameters,
          executionTime = Some(timeMs))
      } catch {
        case e: Exception =>
          BenchmarkResult(
            name = name,
            mode = executionMode.toString,
            parameters = parameters,
            failure = Some(Failure(e.getClass.getSimpleName, e.getMessage)))
      }
    }
  }
}

/**
 * A Variation represents a setting (e.g. the number of shuffle partitions or whether tables
 * are cached in memory) that we want to change in an experiment run.
 * A Variation has three parts: `name`, `options`, and `setup`.
 * The `name` is the identifier of a Variation. `options` is a Seq of options that
 * will be used for a query. Basically, a query will be executed with every option
 * defined in the list of `options`. `setup` defines the action needed for every
 * option. For example, the following Variation is used to change the number of shuffle
 * partitions of a query. The name of the Variation is "shufflePartitions". There are
 * two options, 200 and 2000. The setup is used to set the value of the property
 * "spark.sql.shuffle.partitions".
 *
 * {{{
 *   Variation("shufflePartitions", Seq("200", "2000")) {
 *     case num => sqlContext.setConf("spark.sql.shuffle.partitions", num)
 *   }
 * }}}
 */
case class Variation[T](name: String, options: Seq[T])(val setup: T => Unit)

case class Table(
    name: String,
    data: Dataset[_])

object Benchmark {
  class ExperimentStatus(
      executionsToRun: Seq[Benchmarkable],
      includeBreakdown: Boolean,
      iterations: Int,
      variations: Seq[Variation[_]],
      tags: Map[String, String],
      timeout: Long,
      resultsLocation: String,
      sqlContext: SQLContext,
      allTables: Seq[Table],
      currentConfiguration: BenchmarkConfiguration,
      forkThread: Boolean = true) {

    val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]()
    val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]()
    val currentMessages = new collection.mutable.ArrayBuffer[String]()

    def logMessage(msg: String) = {
      println(msg)
      currentMessages += msg
    }

    // Stats for HTML status message.
    @volatile var currentExecution = ""
    @volatile var currentPlan = "" // for queries only
    @volatile var currentConfig = ""
    @volatile var failures = 0
    @volatile var startTime = 0L

    /**
     * An optional log collection task that will run after the experiment.
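     *
     * [[scheduleCpuCollection]] replaces this no-op with a task that gathers the CPU logs
     * written during the run; a custom task can also be assigned directly, for example
     * (a sketch, with an illustrative message only):
     * {{{
     *   experiment.logCollection = () => experiment.logMessage("experiment finished")
     * }}}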
     */
    @volatile var logCollection: () => Unit = () => {}

    // Cartesian product of lists, e.g.
    // cartesianProduct(List(List(1, 2), List(3))) == List(List(1, 3), List(2, 3)).
    def cartesianProduct[T](xss: List[List[T]]): List[List[T]] = xss match {
      case Nil => List(Nil)
      case h :: t => for (xh <- h; xt <- cartesianProduct(t)) yield xh :: xt
    }

    val timestamp = System.currentTimeMillis()
    val resultPath = s"$resultsLocation/timestamp=$timestamp"
    val combinations =
      cartesianProduct(variations.map(l => (0 until l.options.size).toList).toList)

    val resultsFuture = Future {
      // If we're running queries, create tables for them.
      executionsToRun
        .collect { case query: Query => query }
        .flatMap { query =>
          try {
            query.newDataFrame().queryExecution.logical.collect {
              case r: UnresolvedRelation => r.tableName
            }
          } catch {
            // Ignore the queries that can't be parsed.
            case e: Exception => Seq()
          }
        }
        .distinct
        .foreach { name =>
          try {
            sqlContext.table(name)
            logMessage(s"Table $name exists.")
          } catch {
            case ae: Exception =>
              val table = allTables.find(_.name == name)
              if (table.isDefined) {
                logMessage(s"Creating table: $name")
                table.get.data
                  .write
                  .mode("overwrite")
                  .saveAsTable(name)
              } else {
                // The table could be a subquery.
                logMessage(
                  s"Couldn't read table $name and it's not defined as a Benchmark.Table.")
              }
          }
        }

      // Run the benchmarks!
      val results: Seq[ExperimentRun] = (1 to iterations).flatMap { i =>
        combinations.map { setup =>
          val currentOptions = variations.asInstanceOf[Seq[Variation[Any]]].zip(setup).map {
            case (v, idx) =>
              v.setup(v.options(idx))
              v.name -> v.options(idx).toString
          }
          currentConfig = currentOptions.map { case (k, v) => s"$k: $v" }.mkString(", ")

          val res = executionsToRun.flatMap { q =>
            val setup =
              s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v" }.mkString(", ")}"
            logMessage(s"Running execution ${q.name} $setup")

            currentExecution = q.name
            currentPlan = q match {
              case query: Query =>
                try {
                  query.newDataFrame().queryExecution.executedPlan.toString()
                } catch {
                  case e: Exception => s"failed to parse: $e"
                }
              case _ => ""
            }
            startTime = System.currentTimeMillis()

            val singleResultT = Try {
              q.benchmark(includeBreakdown, setup, currentMessages, timeout,
                forkThread = forkThread)
            }

            singleResultT match {
              case Success(singleResult) =>
                singleResult.failure.foreach { f =>
                  failures += 1
                  logMessage(s"Execution '${q.name}' failed: ${f.message}")
                }
                singleResult.executionTime.foreach { time =>
                  logMessage(s"Execution time: ${time / 1000}s")
                }
                currentResults += singleResult
                singleResult :: Nil
              case SFailure(e) =>
                failures += 1
                logMessage(s"Execution '${q.name}' failed: $e")
                Nil
            }
          }

          val result = ExperimentRun(
            timestamp = timestamp,
            iteration = i,
            tags = currentOptions.toMap ++ tags,
            configuration = currentConfiguration,
            res)

          currentRuns += result

          result
        }
      }

      try {
        val resultsTable = sqlContext.createDataFrame(results)
        logMessage(s"Results written to table: 'sqlPerformance' at $resultPath")
        resultsTable
          .coalesce(1)
          .write
          .format("json")
          .save(resultPath)
      } catch {
        case NonFatal(e) =>
          logMessage(s"Failed to write data: $e")
          throw e
      }

      logCollection()
    }

    def scheduleCpuCollection(fs: FS) = {
      logCollection = () => {
        logMessage("Beginning CPU log collection")
        try {
          val location = cpu.collectLogs(sqlContext, fs, timestamp)
          logMessage(s"CPU results recorded to $location")
        } catch {
          case NonFatal(e) =>
            logMessage(s"Error collecting logs: $e")
            throw e
        }
      }
    }

    def cpuProfile = new Profile(sqlContext, sqlContext.read.json(getCpuLocation(timestamp)))

    def cpuProfileHtml(fs: FS) = {
      s"""
         |<h1>CPU Profile</h1>
         |<b>Permalink:</b> <tt>sqlContext.read.json("${getCpuLocation(timestamp)}")</tt><br/>
         |${cpuProfile.buildGraph(fs)}
       """.stripMargin
    }

    /** Waits for the finish of the experiment. */
    def waitForFinish(timeoutInSeconds: Int) = {
      Await.result(resultsFuture, timeoutInSeconds.seconds)
    }

    /** Returns results from an actively running experiment. */
    def getCurrentResults() = {
      val tbl = sqlContext.createDataFrame(currentResults)
      tbl.createOrReplaceTempView("currentResults")
      tbl
    }

    /** Returns full iterations from an actively running experiment. */
    def getCurrentRuns() = {
      val tbl = sqlContext.createDataFrame(currentRuns)
      tbl.createOrReplaceTempView("currentRuns")
      tbl
    }

    def tail(n: Int = 20) = {
      currentMessages.takeRight(n).mkString("\n")
    }

    def status =
      if (resultsFuture.isCompleted) {
        if (resultsFuture.value.get.isFailure) "Failed" else "Successful"
      } else {
        "Running"
      }

    override def toString =
      s"""Permalink: table("sqlPerformance").where('timestamp === ${timestamp}L)"""

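    /**
     * An HTML snapshot of the experiment's progress. In a Databricks notebook this is
     * typically rendered with something like the following (a sketch; `displayHTML` is the
     * notebook display helper and `experiment` is a running ExperimentStatus):
     * {{{
     *   displayHTML(experiment.html)
     * }}}
     */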
    def html: String = {
      val maybeQueryPlan: String = if (currentPlan.nonEmpty) {
        s"""
           |<h3>QueryPlan</h3>
           |<pre>
           |${currentPlan.replaceAll("\n", "<br/>")}
           |</pre>
         """.stripMargin
      } else {
        ""
      }
      s"""
         |<h2>$status Experiment</h2>
         |<b>Permalink:</b> <tt>sqlContext.read.json("$resultPath")</tt><br/>
         |<b>Iterations complete:</b> ${currentRuns.size / combinations.size} / $iterations<br/>
         |<b>Failures:</b> $failures<br/>
         |<b>Executions run:</b> ${currentResults.size} / ${iterations * combinations.size * executionsToRun.size}
         |<br/>
         |<b>Run time:</b> ${(System.currentTimeMillis() - timestamp) / 1000}s
         |<br/>
         |
         |<h2>Current Execution: $currentExecution</h2>
         |<b>Runtime:</b> ${(System.currentTimeMillis() - startTime) / 1000}s<br/>
         |$currentConfig<br/>
         |$maybeQueryPlan
         |<h2>Logs</h2>
         |<pre>
         |${tail()}
         |</pre>
       """.stripMargin
    }
  }
}