// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0

package org.custom.connector.jdbc.client;

import com.amazonaws.appflow.custom.connector.model.metadata.DescribeEntityRequest;
import com.amazonaws.appflow.custom.connector.model.metadata.Entity;
import com.amazonaws.appflow.custom.connector.model.metadata.FieldDataType;
import com.amazonaws.appflow.custom.connector.model.metadata.FieldDefinition;
import com.amazonaws.appflow.custom.connector.model.metadata.ImmutableEntity;
import com.amazonaws.appflow.custom.connector.model.metadata.ImmutableFieldDefinition;
import com.amazonaws.appflow.custom.connector.model.metadata.ImmutableReadOperationProperty;
import com.amazonaws.appflow.custom.connector.model.metadata.ImmutableWriteOperationProperty;
import com.amazonaws.appflow.custom.connector.model.metadata.ListEntitiesRequest;
import com.amazonaws.appflow.custom.connector.model.query.QueryDataRequest;
import com.amazonaws.appflow.custom.connector.model.write.WriteDataRequest;
import com.amazonaws.appflow.custom.connector.model.write.WriteOperationType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

public final class MySQLClient implements JDBCClient {
  private static final Logger LOGGER = LoggerFactory.getLogger(MySQLClient.class);
  final ObjectMapper objectMapper = new ObjectMapper();
  private Connection conn = null;
  private final Map<String, String> credentials;

  public MySQLClient(final Map<String, String> creds) {
    credentials = creds;
  }

  @Override
  public List<WriteOperationType> getWriteOperations() {
    List<WriteOperationType> writeOperationTypes = new ArrayList<>();
    writeOperationTypes.add(WriteOperationType.UPSERT);
    writeOperationTypes.add(WriteOperationType.UPDATE);
    writeOperationTypes.add(WriteOperationType.INSERT);
    return writeOperationTypes;
  }

  @Override
  public List<Entity> getEntities(final ListEntitiesRequest request) throws SQLException {
    final List<Entity> records = new ArrayList<Entity>();

    Connection conn = getConnection();

    DatabaseMetaData metaData = conn.getMetaData();
    String[] types = {"TABLE"};
    // Retrieving the columns in the database
    ResultSet tables = metaData.getTables(null, null, "%", types);
    while (tables.next()) {
      records.add(
        ImmutableEntity.builder()
          .entityIdentifier(tables.getString("TABLE_NAME"))
          .description(tables.getString("TABLE_NAME"))
          .label(tables.getString("TABLE_NAME"))
          .hasNestedEntities(false)
          .build());
    }
    conn.close();
    return records;
  }

  @Override
  public List<FieldDefinition> getFieldDefinitions(final DescribeEntityRequest request) throws SQLException {
    final List<FieldDefinition> fieldDefinitions = new ArrayList<>();
    Connection conn = getConnection();

    Statement st = conn.createStatement();
    String sql = String.format("DESCRIBE `%s`", request.entityIdentifier());
    ResultSet rs = st.executeQuery(sql);

    while (rs.next()) {
      fieldDefinitions.add(ImmutableFieldDefinition.builder()
        .fieldName(rs.getString(1))
        .dataType(mapFieldType(rs.getString(2)))
        .dataTypeLabel(rs.getString(1))
        .label(rs.getString(1))
        .isPrimaryKey(rs.getString(4).equals("PRI"))
        .readProperties(ImmutableReadOperationProperty.builder()
          .isQueryable(true)
          .isRetrievable(true)
          .build())
        .writeProperties(ImmutableWriteOperationProperty.builder()
          .isNullable(true)
          .isUpdatable(true)
          .isCreatable(true)
          .isDefaultedOnCreate(!rs.getString(4).equals("PRI"))
          .supportedWriteOperations(getWriteOperations())
          .build())
        .build());
    }
    rs.close();
    conn.close();
    return fieldDefinitions;
  }

  @Override
  public Connection getConnection() {
    try {
      if (conn != null && conn.isValid(0)) {
        return conn;
      }
    } catch (SQLException e) {
      // Do nothing for now.
    }

    try {
      String uri = String.format(
        "jdbc:%s://%s:%s/%s?user=%s&password=%s",
        credentials.get("driver"),
        credentials.get("hostname"),
        credentials.get("port"),
        credentials.get("database"),
        credentials.get("username"),
        credentials.get("password")
      );

      conn = DriverManager.getConnection(uri);
    } catch (SQLException ex) {
      // handle any errors
      LOGGER.error("SQLException: " + ex.getMessage());
      LOGGER.error("SQLState: " + ex.getSQLState());
      LOGGER.error("VendorError: " + ex.getErrorCode());
    }
    return conn;
  }

  private FieldDataType mapFieldType(final String mysqlType) {
    String[] temp = mysqlType.split("\\(");
    String mtype = temp[0].toUpperCase();
    switch (mtype) {
      case "ARRAY":
      case "STRUCT":
        return FieldDataType.Struct;
      case "BIGINT":
      case "NUMERIC":
        return FieldDataType.BigInteger;
      case "BINARY":
      case "LONGVARBINARY":
      case "VARBINARY":
        return FieldDataType.ByteArray;
      case "BIT":
      case "SMALLINT":
      case "INTEGER":
      case "TINYINT":
      case "MEDIUMINT":
      case "INT":
        return FieldDataType.Integer;
      case "BLOB":
        return FieldDataType.String;
      case "BOOLEAN":
        return FieldDataType.Boolean;
      case "CHAR":
        return FieldDataType.String;
      case "CLOB":
        return FieldDataType.String;
      case "DATALINK":
        return FieldDataType.String;
      case "DATE":
        return FieldDataType.Date;
      case "DECIMAL":
      case "DOUBLE":
        return FieldDataType.Double;
      case "DISTINCT":
        return FieldDataType.String;
      case "FLOAT":
        return FieldDataType.Float;
      case "JAVA_OBJECT":
        return FieldDataType.Map;
      case "LONGNVARCHAR":
        return FieldDataType.String;
      case "LONGVARCHAR":
        return FieldDataType.String;
      case "NCHAR":
        return FieldDataType.String;
      case "NCLOB":
        return FieldDataType.String;
      case "NULL":
        return FieldDataType.String;
      case "NVARCHAR":
        return FieldDataType.String;
      case "OTHER":
        return FieldDataType.String;
      case "REAL":
        return FieldDataType.Long;
      case "REF":
        return FieldDataType.String;
      case "REF_CURSOR":
        return FieldDataType.String;
      case "ROWID":
        return FieldDataType.String;
      case "SQLXML":
        return FieldDataType.String;
      case "TIME":
      case "TIME_WITH_TIMEZONE":
      case "TIMESTAMP":
      case "TIMESTAMP_WITH_TIMEZONE":
        return FieldDataType.DateTime;
      case "VARCHAR":
        return FieldDataType.String;
      default:
        return FieldDataType.String;
    }
  }

  @Override
  public long getTotalData(final QueryDataRequest request) {
    try (Connection conn = getConnection()) {
      Statement st = conn.createStatement();

      String sql = String.format("SELECT COUNT(*) as cnt FROM `%s`", request.entityIdentifier());
      if (request.filterExpression() != null) {
        sql = sql + String.format(" WHERE %s", request.filterExpression());
      }

      ResultSet rs = st.executeQuery(sql);
      rs.next();
      long count = rs.getLong("cnt");
      st.close();
      conn.close();
      return count;
    } catch (SQLException ex) {
      LOGGER.error("SQLException information");
      while (ex != null) {
        LOGGER.error("Error msg: " + ex.getMessage());
        ex = ex.getNextException();
      }
      throw new RuntimeException("Error");
    }
  }

  @Override
  public List<String> queryData(final QueryDataRequest request) {
    List<String> records = new ArrayList<String>();

    try (Connection conn = getConnection()) {

      Statement st = conn.createStatement();
      String sql = String.format(
        "SELECT %s FROM `%s`", String.join(",", request.selectedFieldNames()), request.entityIdentifier()
      );

      if (request.filterExpression() != null) {
        sql = sql + String.format(" WHERE %s", request.filterExpression());
      }

      if (request.maxResults() != null) {
        int nextToken = 0;
        if (request.nextToken() != null) {
          nextToken = Integer.parseInt(request.nextToken());
        }
        sql = sql + String.format(" LIMIT %s, %s", nextToken, request.maxResults());
      }

      ResultSet rs = st.executeQuery(sql);
      Map<String, String> rows = new HashMap<>();

      while (rs.next()) {
        for (int i = 0; i < request.selectedFieldNames().size(); i++) {
          rows.put(request.selectedFieldNames().get(i), rs.getString(i + 1));
        }
        try {
          records.add(objectMapper.writeValueAsString(rows));
          rows.clear();
        } catch (JsonProcessingException e) {
          e.printStackTrace();
        }
      }
      rs.close();
    } catch (SQLException ex) {
      LOGGER.error("SQLException information");
      while (ex != null) {
        LOGGER.error("Error msg: " + ex.getMessage());
        ex = ex.getNextException();
      }
      throw new RuntimeException("Error");
    }
    return records;
  }

  @Override
  public int[] writeData(final WriteDataRequest request) {
    JsonNode recordJson;

    try (Connection conn = getConnection()) {
      conn.setAutoCommit(true);
      Statement statement = conn.createStatement();

      String sql;

      for (String record : request.records()) {
        sql = "";
        try {
          recordJson = objectMapper.readValue(record, JsonNode.class);
        } catch (JsonProcessingException e) {
          throw new IllegalArgumentException("Invalid record provided for Write operation. Record must be valid JSON", e);
        }
        List<String> keys = new ArrayList<>();
        Iterator<String> iterator = recordJson.fieldNames();
        iterator.forEachRemaining(e -> keys.add(e));

        if (WriteOperationType.INSERT.equals(request.operation()) || WriteOperationType.UPSERT.equals(request.operation())) {
          if (WriteOperationType.UPSERT.equals(request.operation())) {
            sql = "REPLACE";
          } else {
            sql = "INSERT";
          }

          sql += String.format(" INTO `%s` (%s) VALUES (", request.entityIdentifier(), String.join(",", keys));
          String value;

          for (int i = 0; i < keys.size(); i++) {
            value = StringEscapeUtils.escapeJava(getValueFromRecord(recordJson, keys.get(i)));
            if (i > 0) {
              sql += String.format(", \"%s\"", value);
            } else {
              sql += String.format("\"%s\"", value);
            }
          }
          sql += ")";
        } else if (WriteOperationType.UPDATE.equals(request.operation())) {
          if (Objects.requireNonNull(request.idFieldNames()).size() != 1) {
            throw new IllegalArgumentException("A single Id field is required for UPSERT operations in JDBC");
          }

          String recordIdKey = request.idFieldNames().get(0);
          String recordId = getValueFromRecord(recordJson, recordIdKey);
          String value;
          sql = String.format("UPDATE `%s` SET ", request.entityIdentifier());
          for (int i = 0; i < keys.size(); i++) {
            value = StringEscapeUtils.escapeJava(getValueFromRecord(recordJson, keys.get(i)));

            if (i > 0) {
              sql += String.format(", %s = \"%s\"", keys.get(i), value);
            } else {
              sql += String.format("%s = \"%s\"", keys.get(i), value);
            }
          }
          sql += String.format(" WHERE %s = %s", recordIdKey, recordId);
        }
        statement.addBatch(sql);
      }
      int[] records = statement.executeBatch();
      statement.close();
      conn.close();
      return records;
    } catch (SQLException ex) {
      LOGGER.error("SQLException information");
      while (ex != null) {
        LOGGER.error("Error msg: " + ex.getMessage());
        ex = ex.getNextException();
      }
      throw new RuntimeException("Error");
    }
  }

  private String getValueFromRecord(final JsonNode jsonRecord, final String key) {
    if (Objects.isNull(jsonRecord) || Objects.isNull(jsonRecord.get(key))) {
      throw new IllegalArgumentException(key + " key is missing from JSON record but is required");
    }
    if (StringUtils.isEmpty(jsonRecord.get(key).textValue())) {
      return null;
    }
    return jsonRecord.get(key).textValue();
  }
}