package com.amazonaws.services.kinesisanalytics;

import com.amazonaws.services.lambda.AWSLambdaAsync;
import com.amazonaws.services.lambda.AWSLambdaAsyncClientBuilder;
import com.amazonaws.services.lambda.model.InvocationType;
import com.amazonaws.services.lambda.model.InvokeRequest;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * AWS Lambda sink that serializes each record to JSON, buffers records in memory, and invokes
 * a Lambda function with batched payloads as needed.
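 *
 * <p>A minimal usage sketch (the stream {@code events} and record type {@code MyEvent} are
 * illustrative, not part of this library):
 * <pre>{@code
 * DataStream<MyEvent> events = ...;
 * events.addSink(new AwsLambdaSink<>("my-function", MyEvent.class)
 *         .withMaxConcurrency(10)
 *         .withMaxBufferTimeInMillis(30_000L));
 * }</pre>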
 *
 * <p>Use these builder-style methods to tune behaviour:
 * {@link AwsLambdaSink#withAwsLambdaClient(AWSLambdaAsync)}
 * {@link AwsLambdaSink#withMaxRecordsPerFunctionCall(int)}
 * {@link AwsLambdaSink#withMaxBufferTimeInMillis(long)}
 * {@link AwsLambdaSink#withMaxConcurrency(int)}
 * {@link AwsLambdaSink#withSkipBadRecords(boolean)}
 */
public class AwsLambdaSink<T> extends RichSinkFunction<T> implements CheckpointedFunction {

    private static final Logger LOG = LoggerFactory.getLogger(AwsLambdaSink.class);

    private static final long MAX_PAYLOAD_BYTES = 256 * 1000; // 256 KB payload limit for asynchronous Lambda invocation

    private final String functionName;
    private final Class<T> recordType;
    private boolean skipBadRecords = true;
    private int maxRecordsPerFunctionCall = 100;
    private int maxConcurrency = 5; // Maximum number of concurrent Lambda invocations
    private long maxBufferTimeInMillis = 60 * 1000; // 1 minute

    private List<T> bufferedRecords;
    private long bufferedBytes;
    private long lastPublishTime;
    private transient ListState<T> checkpointedState;
    private AWSLambdaAsync awsLambdaAsync;
    private final ObjectMapper jsonParser = new ObjectMapper();

    /**
     * Creates a new AwsLambdaSink.
     *
     * @param functionName the Lambda function to invoke; name, alias, or ARN are supported
     * @param recordType   the type of the records this sink will process
     */
    public AwsLambdaSink(String functionName, Class<T> recordType) {
        this.functionName = functionName;
        this.recordType = recordType;
    }

    /**
     * Sets the AWSLambdaAsync client.
     *
     * @param asyncClient pre-built AWSLambdaAsync client to use; overrides the default created
     *                    with {@link AWSLambdaAsyncClientBuilder#defaultClient()}
     * @return this object for chaining
     */
    public AwsLambdaSink<T> withAwsLambdaClient(AWSLambdaAsync asyncClient) {
        this.awsLambdaAsync = asyncClient;
        return this;
    }

    /**
     * Sets maxRecordsPerFunctionCall.
     *
     * @param maxRecordsPerFunctionCall the maximum number of records to pass per Lambda function invocation. Default: 100
     * @return this object for chaining
     */
    public AwsLambdaSink<T> withMaxRecordsPerFunctionCall(int maxRecordsPerFunctionCall) {
        this.maxRecordsPerFunctionCall = maxRecordsPerFunctionCall;
        return this;
    }

    /**
     * Sets maxBufferTimeInMillis.
     *
     * @param maxBufferTimeInMillis the maximum time to keep records buffered in memory; set this
     *                              according to the allowed lateness. Default: 1 minute
     * @return this object for chaining
     */
    public AwsLambdaSink<T> withMaxBufferTimeInMillis(long maxBufferTimeInMillis) {
        this.maxBufferTimeInMillis = maxBufferTimeInMillis;
        return this;
    }

    /**
     * Sets maxConcurrency.
     *
     * @param maxConcurrency the maximum number of Lambda function calls made concurrently. Default: 5
     * @return this object for chaining
     */
    public AwsLambdaSink<T> withMaxConcurrency(int maxConcurrency) {
        this.maxConcurrency = maxConcurrency;
        return this;
    }

    /**
     * Sets skipBadRecords.
     *
     * @param skipBadRecords whether to continue processing by skipping bad records; set to false
     *                       to fail on a bad record instead. Default: true
     * @return this object for chaining
     */
    public AwsLambdaSink<T> withSkipBadRecords(boolean skipBadRecords) {
        this.skipBadRecords = skipBadRecords;
        return this;
    }

    @Override
    public void open(Configuration parameters) {
        if (this.awsLambdaAsync == null) {
            this.awsLambdaAsync = AWSLambdaAsyncClientBuilder.defaultClient();
        }
        // bufferedRecords is created in initializeState(), which Flink calls before open()
        this.lastPublishTime = System.currentTimeMillis();
        LOG.debug("Opening new sink. lastPublishTime set to " + lastPublishTime);
    }

    @Override
    public void invoke(T value, Context context) throws Exception {
        // Serialize early: JSON processing exceptions surface here, and records over the
        // maximum Lambda payload size are rejected before they enter the buffer
        byte[] valueAsBytes = jsonParser.writeValueAsBytes(value);
        if (valueAsBytes.length > MAX_PAYLOAD_BYTES) {
            if (skipBadRecords) {
                LOG.warn("Skipping record with MD5 hash " + DigestUtils.md5Hex(valueAsBytes)
                        + " as it exceeds the maximum allowed Lambda function payload size.");
                return;
            }
            throw new RuntimeException("Record with MD5 hash " + DigestUtils.md5Hex(valueAsBytes)
                    + " exceeds the maximum allowed Lambda function payload size.");
        }

        // Add the newly received record to the buffer
        bufferedRecords.add(value);
        this.bufferedBytes += valueAsBytes.length;

        if (shouldPublish()) {
            // Split the buffer into batches, each under the record-count and payload-size limits.
            // A batch is sent as a JSON array, so brackets and separators add overhead: {rec1}
            // and {rec2} become [{rec1},{rec2}], adding 2 bytes for the brackets plus a comma
            // between records; a conservative margin is reserved for this in the check below.
            List<List<T>> batches = new ArrayList<>();
            int currentBatchIndex = 0;
            int recordsInCurrentBatch = 0;
            long bytesInCurrentBatch = 0;
            batches.add(new ArrayList<>());
            for (T bufferedRecord : bufferedRecords) {
                String record = jsonParser.writeValueAsString(bufferedRecord);
                recordsInCurrentBatch++;
                bytesInCurrentBatch += record.getBytes().length;
                if (recordsInCurrentBatch > maxRecordsPerFunctionCall
                        || bytesInCurrentBatch > (MAX_PAYLOAD_BYTES - (bufferedRecords.size() * 2L) - 4)) {
                    batches.add(++currentBatchIndex, new ArrayList<>());
                    recordsInCurrentBatch = 1;
                    bytesInCurrentBatch = record.getBytes().length;
                }
                batches.get(currentBatchIndex).add(bufferedRecord);
            }
            LOG.info("Flushing " + batches.size() + " buffered batches. lastPublishTime: " + lastPublishTime
                    + ", bufferedRecords: " + bufferedRecords.size() + ", bufferedBytes: " + bufferedBytes);
            batches.parallelStream().forEach(batch -> {
                try {
                    awsLambdaAsync.invoke(new InvokeRequest()
                            .withFunctionName(functionName)
                            .withInvocationType(InvocationType.Event)
                            .withPayload(jsonParser.writeValueAsString(batch)));
                } catch (JsonProcessingException e) {
                    // Unreachable: every record was already serialized successfully above
                }
            });
            // Reset the buffer once published
            bufferedRecords.clear();
            this.bufferedBytes = 0;
            this.lastPublishTime = System.currentTimeMillis();
        }
    }
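    // Worked sizing example for the margin reserved in invoke() (record sizes illustrative,
    // not from the source): three serialized records of 100 bytes each are sent as the array
    // [r1,r2,r3], i.e. 300 record bytes + 2 bracket bytes + 2 comma bytes. The check reserves
    // 2 bytes per buffered record plus 4 extra bytes; counting all buffered records rather than
    // just the current batch over-reserves slightly, but keeps every batch safely under
    // MAX_PAYLOAD_BYTES.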

    /**
     * Publish records from the buffer when any of the following holds:
     * <ul>
     *   <li>Buffered records exceed the maximum records per function call times the maximum
     *       number of concurrent calls</li>
     *   <li>Total buffered bytes exceed the maximum payload size (256 KB) times the maximum
     *       number of concurrent calls</li>
     *   <li>The maximum buffer time in milliseconds (i.e. the allowed lateness) has elapsed
     *       since the last publish</li>
     * </ul>
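     *
     * <p>For example, with the defaults (100 records per call, concurrency of 5, 256 KB payload
     * cap, 1 minute buffer time), a publish triggers once 500 records are buffered, once roughly
     * 5 * 256 KB are buffered, or one minute after the previous publish, whichever comes first.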
     *
     * @return true if any of the conditions above is met, false otherwise
     */
    private boolean shouldPublish() {
        boolean maxRecordsReached = bufferedRecords.size() >= maxRecordsPerFunctionCall * maxConcurrency;
        boolean maxBufferedBytesReached = (bufferedBytes / MAX_PAYLOAD_BYTES) >= maxConcurrency;
        boolean maxTimeInBufferReached = (lastPublishTime + maxBufferTimeInMillis) <= System.currentTimeMillis();
        LOG.debug("Should publish - maxRecordsReached: " + maxRecordsReached
                + ", maxBufferedBytesReached: " + maxBufferedBytesReached
                + ", maxTimeInBufferReached: " + maxTimeInBufferReached);
        return maxRecordsReached || maxBufferedBytesReached || maxTimeInBufferReached;
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // Replace the checkpointed state with the current in-memory buffer
        checkpointedState.clear();
        for (T element : bufferedRecords) {
            checkpointedState.add(element);
        }
    }

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        ListStateDescriptor<T> descriptor = new ListStateDescriptor<>("recordList", recordType);
        checkpointedState = context.getOperatorStateStore().getListState(descriptor);
        // initializeState() runs before open(), so the buffer must be created here; re-creating
        // it in open() would discard records restored from a checkpoint
        this.bufferedRecords = new ArrayList<>();
        if (context.isRestored()) {
            for (T element : checkpointedState.get()) {
                bufferedRecords.add(element);
            }
        }
    }
}