/*
 * Copyright 2016-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: MIT-0
 */
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util
import java.io.IOException
import scala.collection.JavaConverters._

class DefaultSource extends TableProvider {
  // The schema of this sample source is fixed, so delegate to getTable()
  // (which ignores the null schema argument) and return its schema.
  override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType =
    getTable(
      null,
      Array.empty[Transform],
      caseInsensitiveStringMap.asCaseSensitiveMap()).schema()

  override def getTable(structType: StructType, transforms: Array[Transform], javaProps: util.Map[String, String]): Table = {
    // Special options passed through the connection options can be handled
    // here by inspecting `javaProps`; see the sketch below.
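    // For example (the option name "tableName" is a hypothetical
    // illustration, not an option this sample actually defines):
    //   val tableName = Option(javaProps.get("tableName")).getOrElse("simple_table")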
    new SimpleTable
  }
}
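
// Optionally, a short format name can be registered by also mixing in
// org.apache.spark.sql.sources.DataSourceRegister and listing this class in
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister. A sketch
// (the short name "simple" is a hypothetical choice):
//
//   class DefaultSource extends TableProvider with DataSourceRegister {
//     override def shortName(): String = "simple"
//     ...
//   }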

class SimpleTable extends Table with SupportsRead with SupportsWrite {
  override def name(): String = this.getClass.toString // This sample simply uses the class name as the table name

  override def schema(): StructType =
    new StructType()
      .add("id", "int")
      .add("value", "int")

  override def capabilities(): util.Set[TableCapability] = Set(
    TableCapability.BATCH_READ,
    TableCapability.BATCH_WRITE).asJava // Supports BATCH read/write mode

  // Read
  override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): ScanBuilder =
    new SimpleScanBuilder(this.schema())

  // Write
  override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder =
    new SimpleWriteBuilder
}


/* Read */
class SimpleScanBuilder(schema: StructType) extends ScanBuilder {
  override def build(): Scan = new SimpleScan(schema)
}

class SimpleScan(schema: StructType) extends Scan {
  override def readSchema(): StructType = schema

  override def toBatch: Batch = new SimpleBatch()
}


class SimpleBatch extends Batch {
  override def planInputPartitions(): Array[InputPartition] = {
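    // Two half-open ranges, [0, 5) and [5, 10): the ten generated rows are
    // split evenly across two Spark tasks.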
    Array(
      new SimpleInputPartition(0, 5),
      new SimpleInputPartition(5, 10)
    )
  }

  override def createReaderFactory(): PartitionReaderFactory = new SimpleReaderFactory
}

// Describes a half-open range [start, end) of ids. An InputPartition is
// serialized and sent to executors, so its fields should stay simple.
// preferredLocations() is left at its default (no locality preference).
class SimpleInputPartition(val start: Int, val end: Int) extends InputPartition

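// The reader factory is serialized to executors, where createReader() turns
// each InputPartition back into a PartitionReader for one task.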
class SimpleReaderFactory extends PartitionReaderFactory {
  override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] =
    new SimplePartitionReader(inputPartition.asInstanceOf[SimpleInputPartition])
}

class SimplePartitionReader(simpleInputPartition: SimpleInputPartition) extends PartitionReader[InternalRow] {
  // PartitionReader follows an iterator contract: Spark calls next() to
  // advance and then get() to fetch the current row. Start one position
  // before the range so the first next() lands on `start`; otherwise the
  // first id of each partition would be skipped.
  private var current: Int = simpleInputPartition.start - 1
  private val end: Int = simpleInputPartition.end

  override def next(): Boolean = {
    current += 1
    current < end
  }

  // Each row is (id, -id), matching the (id, value) schema of SimpleTable.
  override def get(): InternalRow = InternalRow.fromSeq(Seq(current, -current))

  @throws[IOException]
  override def close(): Unit = {}
}

/* Write */
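// Note: `buildForBatch()` matches the Spark 3.0/3.1 API used throughout this
// sample; Spark 3.2+ deprecates it in favor of overriding `build(): Write`.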
class SimpleWriteBuilder extends WriteBuilder {
  override def buildForBatch(): BatchWrite =
    new SimpleBatchWrite
}

class SimpleBatchWrite extends BatchWrite {
  // Writes in this sample go straight to stdout, so there is nothing to
  // finalize on commit or roll back on abort.
  override def abort(writerCommitMessages: Array[WriterCommitMessage]): Unit = {}
  override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {}
  override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory =
    new SimpleDataWriterFactory
}

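// A DataWriterFactory is serialized to executors; createWriter() is invoked
// there once per partition and task attempt.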
class SimpleDataWriterFactory extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new SimpleDataWriter(partitionId, taskId)
}

class SimpleDataWriter(partitionId: Int, taskId: Long) extends DataWriter[InternalRow] {
  @throws[IOException]
  override def abort(): Unit = {}

  @throws[IOException]
  override def commit(): WriterCommitMessage = null // No commit message is needed for this sample writer

  @throws[IOException]
  override def write(record: InternalRow): Unit = {
    // This sample simply prints each record for testing purposes.
    println(s"write a record with id : ${record.getInt(0)} and value: ${record.getInt(1)} for partitionId: $partitionId by taskId: $taskId")
  }

  @throws[IOException]
  override def close(): Unit = {}
}
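
/*
 * A minimal usage sketch, not part of the connector itself. Assumptions: a
 * Spark 3.x runtime, and this file compiled without a package declaration,
 * so the format name is simply the class name "DefaultSource" (use the fully
 * qualified name if you add a package).
 */
object SimpleSourceExample {
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("simple-source-example")
      .master("local[*]")
      .getOrCreate()

    // Read: ten rows with id 0..9 and value == -id, planned as two partitions.
    val df = spark.read.format("DefaultSource").load()
    df.show()

    // Write: SimpleDataWriter prints each record on the executors.
    df.write.format("DefaultSource").mode("append").save()

    spark.stop()
  }
}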