package com.amazonaws.samples.heartfunction;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntime;
import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntimeClientBuilder;
import com.amazonaws.services.sagemakerruntime.model.InvokeEndpointRequest;
import com.amazonaws.services.sns.AmazonSNS;
import com.amazonaws.services.sns.AmazonSNSClientBuilder;
import com.amazonaws.services.sns.model.PublishRequest;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import org.apache.http.HttpStatus;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

/* loaded from: input_file:com/amazonaws/samples/heartfunction/LambdaFunctionHandler.class */
/**
 * AWS Lambda handler that sends patient attributes from an API Gateway request to a SageMaker
 * endpoint and publishes an SNS notification for every prediction of heart disease.
 *
 * <p>The endpoint name is read from the {@code SAGEMAKER_ENDPOINT} environment variable and the
 * notification topic from {@code TOPIC_ARN}. Per-invocation results are held in local variables,
 * not fields: Lambda may reuse one handler instance for many invocations, and fields would leak a
 * previous request's prediction into the next response.
 */
public class LambdaFunctionHandler implements RequestHandler<APIGatewayProxyRequestEvent, ApiGatewayResponse> {
    private static final String FEATURES = "features";
    private static final String INSTANCES = "instances";
    private static final String SAGEMAKER_ENDPOINT = System.getenv("SAGEMAKER_ENDPOINT");
    private static final String TOPIC_ARN = System.getenv("TOPIC_ARN");

    /** Predicted label that triggers an SNS heart-disease notification. */
    private static final int HEART_DISEASE_PREDICTION = 1;

    // Created once when the handler is instantiated and reused by every invocation on this instance.
    private final AmazonSageMakerRuntime sageMakerRuntime = AmazonSageMakerRuntimeClientBuilder.defaultClient();
    private final AmazonSNS snsClient = AmazonSNSClientBuilder.defaultClient();

    /**
     * Runs a heart-disease prediction for the JSON body of an API Gateway request.
     *
     * @param event   API Gateway proxy event whose body is a JSON object carrying the fields read
     *                by {@link #buildFeatures(JSONObject)}
     * @param context Lambda context, used for logging
     * @return 200 with the label and score of the last prediction returned by the endpoint
     *         (label 0 and score 0.0 when it returns none); 400 when the event has no body, the
     *         body is not valid JSON, or a required feature is missing or mistyped
     */
    @Override // com.amazonaws.services.lambda.runtime.RequestHandler
    public ApiGatewayResponse handleRequest(APIGatewayProxyRequestEvent event, Context context) {
        if (event == null || event.getBody() == null) {
            // Answering "label 0" here would wrongly read as a negative diagnosis.
            context.getLogger().log("request has no body; nothing to predict");
            return buildResponse(HttpStatus.SC_BAD_REQUEST, "Request body is required");
        }
        // NOTE(review): this logs raw patient attributes; confirm that is acceptable for this data.
        context.getLogger().log("incoming event data " + event.toString());

        JSONObject request;
        try {
            List<Object> features = buildFeatures(getEventData(event, context));
            context.getLogger().log("features: " + features.toString());
            request = buildRequest(features);
        } catch (JSONException e) {
            // Malformed JSON or a missing/mistyped feature is a client error, not a server fault.
            context.getLogger().log("invalid request body: " + e.getMessage());
            return buildResponse(HttpStatus.SC_BAD_REQUEST, "Invalid request body: " + e.getMessage());
        }
        context.getLogger().log("SageMaker request data : " + request.toString());

        int predictionLabel = 0;
        double score = 0.0d;
        JSONObject inference = getInference(request, context);
        if (inference != null) {
            context.getLogger().log("Inference response data : " + inference.toString());
            JSONArray predictions = inference.getJSONArray("predictions");
            for (int i = 0; i < predictions.length(); i++) {
                JSONObject prediction = predictions.getJSONObject(i);
                // getInt/getDouble accept any numeric representation (Integer, Double, BigDecimal);
                // a direct (Integer)/(Double) cast breaks when the parser picks another Number type.
                predictionLabel = prediction.getInt("predicted_label");
                score = prediction.getDouble("score");
                context.getLogger().log("recieved predition for heart disease with value of " + predictionLabel);
                context.getLogger().log("prediction confidence level is " + score);
                if (predictionLabel == HEART_DISEASE_PREDICTION) {
                    context.getLogger().log("Heart disases predicted with confidence score of " + score);
                    sendSNSMessage("We have predicted that you may have a potential heart disease with confidence of " + score, context);
                }
            }
        }
        return buildResponse(HttpStatus.SC_OK, "Prediction label is " + predictionLabel + " with confidence of " + score);
    }

    /**
     * Builds an API Gateway response carrying the handler's standard header.
     *
     * @param statusCode HTTP status code
     * @param body       response body
     * @return the built response
     */
    private static ApiGatewayResponse buildResponse(int statusCode, String body) {
        return ApiGatewayResponse.builder()
                .setStatusCode(statusCode)
                .setObjectBody(body)
                .setHeaders(Collections.singletonMap("X-Powered-By", "AWS API Gateway & Lambda Serverless"))
                .build();
    }

    /**
     * Publishes {@code message} to the topic named by {@code TOPIC_ARN} and logs the resulting
     * message ID.
     */
    private void sendSNSMessage(String message, Context context) {
        String messageId = this.snsClient.publish(new PublishRequest(TOPIC_ARN, message)).getMessageId();
        context.getLogger().log("published message with following ID " + messageId);
    }

    /**
     * Invokes the SageMaker endpoint with {@code request} serialized as UTF-8 JSON.
     *
     * @return the endpoint's response body parsed as a JSON object
     */
    private JSONObject getInference(JSONObject request, Context context) {
        context.getLogger().log("Getting SageMaker inference to predict heart disease ");
        InvokeEndpointRequest invokeEndpointRequest = new InvokeEndpointRequest();
        invokeEndpointRequest.setContentType("application/json");
        // StandardCharsets avoids the checked UnsupportedEncodingException, which was previously
        // swallowed and left the request without a body.
        invokeEndpointRequest.setBody(ByteBuffer.wrap(request.toString().getBytes(StandardCharsets.UTF_8)));
        invokeEndpointRequest.setEndpointName(SAGEMAKER_ENDPOINT);
        ByteBuffer responseBody = this.sageMakerRuntime.invokeEndpoint(invokeEndpointRequest).getBody();
        return new JSONObject(StandardCharsets.UTF_8.decode(responseBody).toString());
    }

    /**
     * Wraps a feature list as {@code {"instances": [{"features": [...]}]}}.
     *
     * @return the request object, or {@code null} when {@code features} is null or empty
     */
    private JSONObject buildRequest(List<Object> features) {
        if (features == null || features.isEmpty()) {
            return null;
        }
        JSONObject instance = new JSONObject();
        // The cast selects put(String, Collection), which stores the list as a JSON array.
        instance.put(FEATURES, (Collection<?>) features);
        JSONArray instances = new JSONArray();
        instances.put(instance);
        JSONObject request = new JSONObject();
        request.put(INSTANCES, instances);
        return request;
    }

    /**
     * Parses the event body as a JSON object.
     *
     * @throws JSONException if the body is not a JSON object
     */
    private JSONObject getEventData(APIGatewayProxyRequestEvent event, Context context) {
        context.getLogger().log("Event body is " + event.getBody());
        return new JSONObject(event.getBody());
    }

    /**
     * Extracts the model's feature vector from the request JSON, in the order the model is sent them.
     *
     * @return the features, or an empty list when {@code body} is null
     * @throws JSONException if a required field is missing or has an unexpected type
     */
    private List<Object> buildFeatures(JSONObject body) {
        List<Object> features = new ArrayList<>(13);
        if (body != null) {
            // NOTE(review): "age" is read as a string while the other numeric fields are read as ints;
            // confirm this matches what the endpoint expects.
            features.add(body.getString("age"));
            features.add(body.getInt("sex"));
            features.add(body.getInt("cp"));
            features.add(body.getInt("trestbps"));
            features.add(body.getInt("chol"));
            features.add(body.getInt("fbs"));
            features.add(body.getInt("restecg"));
            features.add(body.getInt("thalach"));
            // These are passed through with whatever JSON type the client sent.
            features.add(body.get("exang"));
            features.add(body.get("oldpeak"));
            features.add(body.get("slope"));
            features.add(body.get("ca"));
            features.add(body.get("thal"));
        }
        return features;
    }
}
