package com.amazonaws.kvsmlinference; import com.amazonaws.SdkClientException; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.kinesisvideo.parser.ebml.InputStreamParserByteSource; import com.amazonaws.kinesisvideo.parser.mkv.StreamingMkvReader; import com.amazonaws.kinesisvideo.parser.utilities.FragmentMetadataVisitor; import com.amazonaws.regions.Regions; import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestHandler; import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntime; import com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntimeClientBuilder; import com.amazonaws.services.sagemakerruntime.model.InvokeEndpointRequest; import com.amazonaws.services.sagemakerruntime.model.InvokeEndpointResult; import com.amazonaws.services.sagemakerruntime.model.AmazonSageMakerRuntimeException; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; import com.amazonaws.services.dynamodbv2.document.DynamoDB; import com.amazonaws.services.dynamodbv2.document.Item; import com.amazonaws.services.dynamodbv2.document.PutItemOutcome; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import java.io.*; import java.nio.ByteBuffer; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.charset.StandardCharsets; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.time.Instant; /** * Demonstrate Amazon Connect's real-time transcription feature using AWS Kinesis Video Streams and AWS Transcribe. * *

Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.

* * Permission is hereby granted, free of charge, to any person obtaining a copy of this * software and associated documentation files (the "Software"), to deal in the Software * without restriction, including without limitation the rights to use, copy, modify, * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so. *

* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ public class KVSMLInferenceLambda implements RequestHandler { private static final Regions REGION = Regions.fromName(System.getenv("APP_REGION")); private static final String RECORDINGS_BUCKET_NAME = System.getenv("RECORDINGS_BUCKET_NAME"); private static final String RECORDINGS_KEY_PREFIX = System.getenv("RECORDINGS_KEY_PREFIX"); private static final boolean CONSOLE_LOG_TRANSCRIPT_FLAG = Boolean.parseBoolean(System.getenv("CONSOLE_LOG_TRANSCRIPT_FLAG")); private static final boolean RECORDINGS_PUBLIC_READ_ACL = Boolean.parseBoolean(System.getenv("RECORDINGS_PUBLIC_READ_ACL")); private static final String START_SELECTOR_TYPE = System.getenv("START_SELECTOR_TYPE"); private static final String SM_ENDPOINT_NAME = System.getenv("SM_ENDPOINT_NAME"); private static final String TABLE_ML_INFERENCE = System.getenv("TABLE_ML_INFERENCE"); private static final Logger logger = LoggerFactory.getLogger(KVSMLInferenceLambda.class); public static final MetricsUtil metricsUtil = new MetricsUtil(AmazonCloudWatchClientBuilder.defaultClient()); private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ"); /** * Handler function for the Lambda * * @param request * @param context * @return */ @Override public String handleRequest(AudioPredictionRequest request, Context context) { logger.info("received request : " + request.toString()); logger.info("received context: " + context.toString()); try { // validate the request request.validate(); startKVSToPredictionStreaming(request.getStreamARN(), request.getStartFragmentNum(), request.getConnectContactId(), request.getSaveCallRecording()); return "{ \"result\": \"Success\" }"; } catch (Exception e) { logger.error("KVS to Transcribe Streaming failed with: ", e); return "{ \"result\": \"Failed\" }"; } } /** * Starts streaming between KVS and Transcribe * The transcript segments are continuously saved to the Dynamo DB table * At end of the streaming session, the raw audio is saved as an s3 object * * @param streamARN * @param startFragmentNum * @param contactId * @param saveCallRecording * @throws Exception */ private void startKVSToPredictionStreaming(String streamARN, String startFragmentNum, String contactId, Optional saveCallRecording) throws Exception { String streamName = streamARN.substring(streamARN.indexOf("/") + 1, streamARN.lastIndexOf("/")); KVSStreamTrackObject kvsStreamTrackObjectFromCustomer = getKVSStreamTrackObject(streamName, startFragmentNum, KVSUtils.TrackName.AUDIO_FROM_CUSTOMER.getName(), contactId); logger.info("Start to process KVS streaming and make prediction."); if (kvsStreamTrackObjectFromCustomer != null) { // get audio streaming from KVS to local file ByteBuffer audioBuffer = KVSUtils.getByteBufferFromStream(kvsStreamTrackObjectFromCustomer.getStreamingMkvReader(), kvsStreamTrackObjectFromCustomer.getFragmentVisitor(), kvsStreamTrackObjectFromCustomer.getTagProcessor(), contactId, kvsStreamTrackObjectFromCustomer.getTrackName()); while (audioBuffer.remaining() > 0) { byte[] audioBytes = new byte[audioBuffer.remaining()]; audioBuffer.get(audioBytes); kvsStreamTrackObjectFromCustomer.getOutputStream().write(audioBytes); audioBuffer = KVSUtils.getByteBufferFromStream(kvsStreamTrackObjectFromCustomer.getStreamingMkvReader(), kvsStreamTrackObjectFromCustomer.getFragmentVisitor(), kvsStreamTrackObjectFromCustomer.getTagProcessor(), contactId, kvsStreamTrackObjectFromCustomer.getTrackName()); } String audioFilePath = kvsStreamTrackObjectFromCustomer.getSaveAudioFilePath().toString(); File audioFile = new File(audioFilePath); logger.info("file path: "+audioFilePath); logger.info("file size: "+audioFile.length()); //Upload the Raw Audio file to S3 kvsStreamTrackObjectFromCustomer.getInputStream().close(); kvsStreamTrackObjectFromCustomer.getOutputStream().close(); if (audioFile.length() > 0) { String s3path = AudioUtils.uploadRawAudio(REGION, RECORDINGS_BUCKET_NAME, RECORDINGS_KEY_PREFIX, kvsStreamTrackObjectFromCustomer.getSaveAudioFilePath().toString(), contactId, RECORDINGS_PUBLIC_READ_ACL, getAWSCredentials()); if (s3path.length()>1) { logger.info("Audio file uploaded successfully to: " + s3path); try { //Invoke SageMaker Inference endpoint AmazonSageMakerRuntime smclient = AmazonSageMakerRuntimeClientBuilder .standard() .withRegion(REGION) .withCredentials(getAWSCredentials()) .build(); InvokeEndpointRequest invokeEndpointRequest = new InvokeEndpointRequest(); invokeEndpointRequest.setContentType("text/csv"); invokeEndpointRequest.setEndpointName(SM_ENDPOINT_NAME); invokeEndpointRequest.setBody(ByteBuffer.wrap(s3path.getBytes("UTF-8"))); InvokeEndpointResult result = smclient.invokeEndpoint(invokeEndpointRequest); String body = StandardCharsets.UTF_8.decode(result.getBody()).toString(); logger.info("SageMaker Inference result for the probability of positive class: "+body); //Write to DynamoDB AmazonDynamoDB ddbbuilder = AmazonDynamoDBClientBuilder .standard() .withRegion(REGION) .build(); DynamoDB ddbclient = new DynamoDB(ddbbuilder); Instant now = Instant.now(); Item ddbItem = new Item() .withKeyComponent("ContactId", contactId) .withKeyComponent("StartTime", now.toEpochMilli()) .withString("predictionTime", now.toString()) .withString("predictionBody", body); PutItemOutcome outcome = ddbclient.getTable(TABLE_ML_INFERENCE).putItem(ddbItem); logger.info("DynamoDB putItem result: "+outcome.toString()); } catch (UnsupportedEncodingException e) { logger.error("Failed to invoke SageMaker Endpoint: ", e); } catch (SdkClientException e) { logger.error("Failed to invoke SageMaker Endpoint: ", e); } catch (Exception e) { logger.error("Exception while writing to DDB: ", e); } } } else { logger.info("Skipping upload to S3. saveCallRecording was disabled or audio file has 0 bytes: " + kvsStreamTrackObjectFromCustomer.getSaveAudioFilePath().toString()); } } } /** * Create all objects necessary for KVS streaming from each track * * @param streamName * @param startFragmentNum * @param trackName * @param contactId * @return * @throws FileNotFoundException */ private KVSStreamTrackObject getKVSStreamTrackObject(String streamName, String startFragmentNum, String trackName, String contactId) throws FileNotFoundException { InputStream kvsInputStream = KVSUtils.getInputStreamFromKVS(streamName, REGION, startFragmentNum, getAWSCredentials(), START_SELECTOR_TYPE); StreamingMkvReader streamingMkvReader = StreamingMkvReader.createDefault(new InputStreamParserByteSource(kvsInputStream)); KVSContactTagProcessor tagProcessor = new KVSContactTagProcessor(contactId); FragmentMetadataVisitor fragmentVisitor = FragmentMetadataVisitor.create(Optional.of(tagProcessor)); String fileName = String.format("%s_%s_%s.raw", contactId, DATE_FORMAT.format(new Date()), trackName); Path saveAudioFilePath = Paths.get("/tmp", fileName); FileOutputStream fileOutputStream = new FileOutputStream(saveAudioFilePath.toString()); return new KVSStreamTrackObject(kvsInputStream, streamingMkvReader, tagProcessor, fragmentVisitor, saveAudioFilePath, fileOutputStream, trackName); } /** * @return AWS credentials to be used to connect to s3 (for fetching and uploading audio) and KVS */ private static AWSCredentialsProvider getAWSCredentials() { return DefaultAWSCredentialsProviderChain.getInstance(); } }