/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package com.amplifyframework.predictions.aws.service

import android.graphics.RectF
import aws.sdk.kotlin.services.rekognition.RekognitionClient
import aws.sdk.kotlin.services.rekognition.detectFaces
import aws.sdk.kotlin.services.rekognition.detectLabels
import aws.sdk.kotlin.services.rekognition.detectModerationLabels
import aws.sdk.kotlin.services.rekognition.detectText
import aws.sdk.kotlin.services.rekognition.model.Attribute
import aws.sdk.kotlin.services.rekognition.model.Image
import aws.sdk.kotlin.services.rekognition.model.TextTypes
import aws.sdk.kotlin.services.rekognition.recognizeCelebrities
import aws.sdk.kotlin.services.rekognition.searchFacesByImage
import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider
import com.amplifyframework.core.Consumer
import com.amplifyframework.predictions.PredictionsException
import com.amplifyframework.predictions.aws.AWSPredictionsPluginConfiguration
import com.amplifyframework.predictions.aws.adapter.EmotionTypeAdapter
import com.amplifyframework.predictions.aws.adapter.GenderBinaryTypeAdapter
import com.amplifyframework.predictions.aws.adapter.RekognitionResultTransformers
import com.amplifyframework.predictions.models.Celebrity
import com.amplifyframework.predictions.models.CelebrityDetails
import com.amplifyframework.predictions.models.Emotion
import com.amplifyframework.predictions.models.EntityDetails
import com.amplifyframework.predictions.models.EntityMatch
import com.amplifyframework.predictions.models.Gender
import com.amplifyframework.predictions.models.IdentifiedText
import com.amplifyframework.predictions.models.Label
import com.amplifyframework.predictions.models.LabelType
import com.amplifyframework.predictions.result.IdentifyCelebritiesResult
import com.amplifyframework.predictions.result.IdentifyEntitiesResult
import com.amplifyframework.predictions.result.IdentifyEntityMatchesResult
import com.amplifyframework.predictions.result.IdentifyLabelsResult
import com.amplifyframework.predictions.result.IdentifyResult
import com.amplifyframework.predictions.result.IdentifyTextResult
import java.lang.StringBuilder
import java.net.MalformedURLException
import java.net.URL
import java.nio.ByteBuffer
import java.util.Collections
import java.util.concurrent.Executors
import kotlinx.coroutines.runBlocking

/**
 * Predictions service for performing image analysis.
 */
internal class AWSRekognitionService(
    private val pluginConfiguration: AWSPredictionsPluginConfiguration,
    private val authCredentialsProvider: CredentialsProvider
) {

    val client: RekognitionClient = RekognitionClient {
        this.region = pluginConfiguration.defaultRegion
        this.credentialsProvider = authCredentialsProvider
    }

    private val executor = Executors.newCachedThreadPool()

    fun detectLabels(
        type: LabelType,
        imageData: ByteBuffer,
        onSuccess: Consumer<IdentifyResult>,
        onError: Consumer<PredictionsException>
    ) {
        execute(
            {
                val labels: MutableList<Label> = ArrayList()
                var unsafeContent = false
                // Moderation labels detection
                if (LabelType.ALL == type || LabelType.MODERATION_LABELS == type) {
                    labels.addAll(detectModerationLabels(imageData))
                    unsafeContent = labels.isNotEmpty()
                }
                // Regular labels detection
                if (LabelType.ALL == type || LabelType.LABELS == type) {
                    labels.addAll(detectLabels(imageData))
                }
                IdentifyLabelsResult.builder()
                    .labels(labels)
                    .unsafeContent(unsafeContent)
                    .build()
            },
            { throwable ->
                PredictionsException(
                    "Amazon Rekognition encountered an error while detecting labels.",
                    throwable,
                    "See attached exception for more details."
                )
            },
            onSuccess,
            onError
        )
    }

    fun recognizeCelebrities(
        imageData: ByteBuffer,
        onSuccess: Consumer<IdentifyResult>,
        onError: Consumer<PredictionsException>
    ) {
        val config = pluginConfiguration.identifyEntitiesConfiguration
        if (!config.isCelebrityDetectionEnabled) {
            onError.accept(
                PredictionsException(
                    "Celebrity detection is disabled.",
                    "Please enable celebrity detection via Amplify CLI. This feature should be accessible by " +
                        "running `amplify update predictions` in the console and updating entities " +
                        "detection resource with advanced configuration setting."
                )
            )
            return
        }

        execute(
            {
                val celebrities = detectCelebrities(imageData)
                IdentifyCelebritiesResult.fromCelebrities(celebrities)
            },
            { throwable ->
                PredictionsException(
                    "Amazon Rekognition encountered an error while recognizing celebrities.",
                    throwable,
                    "See attached exception for more details."
                )
            },
            onSuccess,
            onError
        )
    }

    fun detectEntities(
        imageData: ByteBuffer,
        onSuccess: Consumer<IdentifyResult>,
        onError: Consumer<PredictionsException>
    ) {
        execute(
            {
                val config = pluginConfiguration.identifyEntitiesConfiguration
                if (config.isGeneralEntityDetection) {
                    val entities = detectEntities(imageData)
                    IdentifyEntitiesResult.fromEntityDetails(entities)
                } else {
                    val maxEntities = config.maxEntities
                    val collectionId = config.collectionId
                    val matches = detectEntityMatches(imageData, maxEntities, collectionId)
                    IdentifyEntityMatchesResult.fromEntityMatches(matches)
                }
            },
            { throwable ->
                PredictionsException(
                    "Amazon Rekognition encountered an error while either detecting or searching for faces.",
                    throwable,
                    "See attached exception for more details."
                )
            },
            onSuccess,
            onError
        )
    }

    fun detectPlainText(
        imageData: ByteBuffer,
        onSuccess: Consumer<IdentifyResult>,
        onError: Consumer<PredictionsException>
    ) {
        execute(
            {
                detectPlainText(imageData)
            },
            { throwable ->
                PredictionsException(
                    "Amazon Rekognition encountered an error while detecting text.",
                    throwable,
                    "See attached exception for more details."
                )
            },
            onSuccess,
            onError
        )
    }

    @Throws(PredictionsException::class)
    private suspend fun detectLabels(imageData: ByteBuffer): List<Label> {
        // Detect labels in the given image via Amazon Rekognition
        val result = try {
            client.detectLabels {
                this.image = Image {
                    this.bytes = imageData.array()
                }
            }
        } catch (exception: Exception) {
            throw PredictionsException(
                "Amazon Rekognition encountered an error while detecting labels.",
                exception,
                "See attached exception for more details."
            )
        }
        val labels: MutableList<Label> = ArrayList()
        result.labels?.forEach { rekognitionLabel ->
            val parents: MutableList<String> = ArrayList()
            rekognitionLabel.parents?.forEach { parent ->
                parent.name?.let { parentName -> parents.add(parentName) }
            }
            val boxes: MutableList<RectF?> = ArrayList()
            rekognitionLabel.instances?.forEach { instance ->
                boxes.add(RekognitionResultTransformers.fromBoundingBox(instance.boundingBox))
            }
            rekognitionLabel.name?.let { labelName ->
                rekognitionLabel.confidence?.let { labelConfidence ->
                    val amplifyLabel = Label.builder()
                        .value(labelName)
                        .confidence(labelConfidence)
                        .parentLabels(parents)
                        .boxes(boxes)
                        .build()
                    labels.add(amplifyLabel)
                }
            }
        }
        return labels
    }

    @Throws(PredictionsException::class)
    private suspend fun detectModerationLabels(imageData: ByteBuffer): List<Label> {
        val result = try {
            client.detectModerationLabels {
                this.image = Image {
                    this.bytes = imageData.array()
                }
            }
        } catch (exception: Exception) {
            throw PredictionsException(
                "Amazon Rekognition encountered an error while detecting moderation labels.",
                exception,
                "See attached exception for more details."
            )
        }
        val labels: MutableList<Label> = ArrayList()
        result.moderationLabels?.forEach { moderationLabel ->
            moderationLabel.name?.let { labelName ->
                moderationLabel.confidence?.let { labelConfidence ->
                    val label = Label.builder()
                        .value(labelName)
                        .confidence(labelConfidence)
                        .parentLabels(listOf(moderationLabel.parentName))
                        .build()
                    labels.add(label)
                }
            }
        }
        return labels
    }

    private suspend fun detectCelebrities(imageData: ByteBuffer): List<CelebrityDetails> {
        val result = client.recognizeCelebrities {
            this.image = Image {
                this.bytes = imageData.array()
            }
        }
        val celebrities: MutableList<CelebrityDetails> = ArrayList()
        result.celebrityFaces?.forEach { rekognitionCelebrity ->
            val amplifyCelebrity = rekognitionCelebrity.id?.let { celebrityId ->
                rekognitionCelebrity.name?.let { celebrityName ->
                    rekognitionCelebrity.matchConfidence?.let { celebrityMatchConfidence ->
                        Celebrity.builder()
                            .id(celebrityId)
                            .value(celebrityName)
                            .confidence(celebrityMatchConfidence)
                            .build()
                    }
                }
            }
            // Get face-specific celebrity details from the result
            val face = rekognitionCelebrity.face
            val box = RekognitionResultTransformers.fromBoundingBox(face?.boundingBox)
            val pose = RekognitionResultTransformers.fromRekognitionPose(face?.pose)
            val landmarks = RekognitionResultTransformers.fromLandmarks(face?.landmarks)

            // Get URL links that are relevant to celebrities
            val urls: MutableList<URL> = ArrayList()
            rekognitionCelebrity.urls?.forEach { url ->
                try {
                    urls.add(URL(url))
                } catch (badUrl: MalformedURLException) {
                    // Ignore bad URL
                }
            }
            amplifyCelebrity?.let {
                val details = CelebrityDetails.builder()
                    .celebrity(it)
                    .box(box)
                    .pose(pose)
                    .landmarks(landmarks)
                    .urls(urls)
                    .build()
                celebrities.add(details)
            }
        }
        return celebrities
    }

    @Throws(PredictionsException::class)
    private suspend fun detectEntities(imageData: ByteBuffer): List<EntityDetails> {
        val result = try {
            client.detectFaces {
                this.image = Image {
                    this.bytes = imageData.array()
                }
                this.attributes = mutableListOf(Attribute.All)
            }
        } catch (exception: Exception) {
            throw PredictionsException(
                "Amazon Rekognition encountered an error while detecting faces.",
                exception,
                "See attached exception for more details."
            )
        }
        val entities: MutableList<EntityDetails> = ArrayList()
        result.faceDetails?.forEach { face ->
            // Extract details from face detection
            val box = RekognitionResultTransformers.fromBoundingBox(face.boundingBox)
            val ageRange = RekognitionResultTransformers.fromRekognitionAgeRange(face.ageRange)
            val pose = RekognitionResultTransformers.fromRekognitionPose(face.pose)
            val landmarks = RekognitionResultTransformers.fromLandmarks(face.landmarks)
            val features = RekognitionResultTransformers.fromFaceDetail(face)

            // Gender detection
            val amplifyGender = face.gender?.let { faceGender ->
                faceGender.confidence?.let { faceGenderConfidence ->
                    Gender.builder()
                        .value(GenderBinaryTypeAdapter.fromRekognition(faceGender.value?.value ?: ""))
                        .confidence(faceGenderConfidence)
                        .build()
                }
            }

            // Emotion detection
            val emotions: MutableList<Emotion> = ArrayList()
            face.emotions?.forEach { rekognitionEmotion ->
                val emotion = EmotionTypeAdapter.fromRekognition(rekognitionEmotion.type?.value ?: "")
                rekognitionEmotion.confidence?.let { emotionConfidence ->
                    val amplifyEmotion = Emotion.builder()
                        .value(emotion)
                        .confidence(emotionConfidence)
                        .build()
                    emotions.add(amplifyEmotion)
                }
            }
            Collections.sort(emotions, Collections.reverseOrder())
            val entity = EntityDetails.builder()
                .box(box)
                .ageRange(ageRange)
                .pose(pose)
                .gender(amplifyGender)
                .landmarks(landmarks)
                .emotions(emotions)
                .features(features)
                .build()
            entities.add(entity)
        }
        return entities
    }

    @Throws(PredictionsException::class)
    private suspend fun detectEntityMatches(
        imageData: ByteBuffer,
        maxEntities: Int,
        collectionId: String
    ): List<EntityMatch> {

        val result = try {
            client.searchFacesByImage {
                this.image = Image {
                    this.bytes = imageData.array()
                }
                this.maxFaces = maxEntities
                this.collectionId = collectionId
            }
        } catch (exception: Exception) {
            throw PredictionsException(
                "Amazon Rekognition encountered an error while searching for known faces.",
                exception,
                "See attached exception for more details."
            )
        }
        val matches: MutableList<EntityMatch> = ArrayList()
        result.faceMatches?.forEach { rekognitionMatch ->
            val box = RekognitionResultTransformers.fromBoundingBox(rekognitionMatch.face?.boundingBox)
            rekognitionMatch.face?.externalImageId?.let { faceImageId ->
                rekognitionMatch.similarity?.let { matchSimilarity ->
                    val amplifyMatch = EntityMatch.builder()
                        .externalImageId(faceImageId)
                        .confidence(matchSimilarity)
                        .box(box)
                        .build()
                    matches.add(amplifyMatch)
                }
            }
        }
        return matches
    }

    private suspend fun detectPlainText(imageData: ByteBuffer): IdentifyTextResult {
        val result = client.detectText {
            this.image = Image {
                this.bytes = imageData.array()
            }
        }
        val fullTextBuilder = StringBuilder()
        val rawLineText: MutableList<String> = ArrayList()
        val words: MutableList<IdentifiedText?> = ArrayList()
        val lines: MutableList<IdentifiedText?> = ArrayList()
        result.textDetections?.forEach { detection ->
            when (TextTypes.fromValue(detection.type?.value ?: "")) {
                TextTypes.Line -> {
                    detection.detectedText?.let { rawLineText.add(it) }
                    lines.add(RekognitionResultTransformers.fromTextDetection(detection))
                }
                TextTypes.Word -> {
                    fullTextBuilder.append(detection.detectedText).append(" ")
                    words.add(RekognitionResultTransformers.fromTextDetection(detection))
                }
                else -> { }
            }
        }
        return IdentifyTextResult.builder()
            .fullText(fullTextBuilder.toString().trim { it <= ' ' })
            .rawLineText(rawLineText)
            .lines(lines)
            .words(words)
            .build()
    }

    private fun <T : Any> execute(
        runnableTask: suspend () -> T,
        errorTransformer: (Throwable) -> PredictionsException,
        onResult: Consumer<T>,
        onError: Consumer<PredictionsException>
    ) {
        executor.execute {
            try {
                runBlocking {
                    val result = runnableTask()
                    onResult.accept(result)
                }
            } catch (error: Throwable) {
                val predictionsException = if (error is PredictionsException) {
                    error
                } else {
                    errorTransformer.invoke(error)
                }
                onError.accept(predictionsException)
            }
        }
    }
}