package com.databricks.spark.sql.perf.mllib.feature

import org.apache.spark.ml
import org.apache.spark.ml.PipelineStage
import org.apache.spark.sql._

import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator

/**
 * Benchmark definition measuring the performance of Spark ML's
 * `BucketedRandomProjectionLSH` feature transformer.
 *
 * Training data is a synthetic DataFrame built by
 * `DataGenerator.generateContinuousFeatures`. The benchmarked stage is an unfitted
 * LSH estimator reading the `features` column. Test data handling comes from the
 * [[TestFromTraining]] mixin (presumably the training set is reused; confirm there).
 */
object BucketedRandomProjectionLSH extends BenchmarkAlgorithm with TestFromTraining {

  /**
   * Builds the training DataFrame for this benchmark.
   *
   * @param ctx benchmark context; supplies the SQL context, the seed and the sizing
   *            parameters `numExamples`, `numPartitions` and `numFeatures`
   * @return the DataFrame produced by `DataGenerator.generateContinuousFeatures`
   */
  override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
    // `ctx.params` is a stable value (the original imported from it), so aliasing it
    // reads exactly the same members as a wildcard import would.
    val p = ctx.params
    DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      p.numExamples,
      ctx.seed(),
      p.numPartitions,
      p.numFeatures)
  }

  /**
   * Creates the (unfitted) estimator under test.
   *
   * @param ctx benchmark context; only `params.numHashTables` is read
   * @return a `BucketedRandomProjectionLSH` with input column `features` and the
   *         configured number of hash tables
   */
  override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
    val lsh = new ml.feature.BucketedRandomProjectionLSH()
    // NOTE(review): bucketLength is never set on this estimator. Confirm that the
    // Spark version in use supplies a value before the fitted model hashes data.
    lsh
      .setInputCol("features")
      .setNumHashTables(ctx.params.numHashTables)
  }
}