// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0
using System;
using System.Reflection;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Linq;
using System.Threading.Tasks;
using System.Threading;
namespace csharp_example_extension
{
    /// <summary>
    /// Lambda Extension API client.
    /// </summary>
    internal class ExtensionClient : IDisposable
    {
        #region HTTP header key names

        /// <summary>
        /// HTTP header that is used to register a new extension name with Extension API.
        /// </summary>
        private const string LambdaExtensionNameHeader = "Lambda-Extension-Name";

        /// <summary>
        /// HTTP header used to provide extension registration id.
        /// </summary>
        /// <remarks>
        /// Registration endpoint reply will have this header value with a new id, assigned to this extension by the API.
        /// All other endpoints will expect HTTP calls to have id header attached to all requests.
        /// </remarks>
        private const string LambdaExtensionIdHeader = "Lambda-Extension-Identifier";

        /// <summary>
        /// HTTP header to report Lambda Extension error type string.
        /// </summary>
        /// <remarks>
        /// This header is used to report additional error details for Init and Shutdown errors.
        /// </remarks>
        private const string LambdaExtensionFunctionErrorTypeHeader = "Lambda-Extension-Function-Error-Type";

        #endregion

        #region Environment variable names

        /// <summary>
        /// Environment variable that holds server name and port number for Extension API endpoints.
        /// </summary>
        private const string LambdaRuntimeApiAddress = "AWS_LAMBDA_RUNTIME_API";

        #endregion

        #region Instance properties

        /// <summary>
        /// Extension id, which is assigned to this extension by the registration call.
        /// </summary>
        public string Id { get; private set; }

        #endregion

        #region Constructor and readonly variables

        /// <summary>
        /// Http client instance shared by all Extension API calls.
        /// </summary>
        /// <remarks>
        /// This is an <see cref="IDisposable"/> object that must be properly disposed of,
        /// which is why this class implements <see cref="IDisposable"/> too.
        /// </remarks>
        private readonly HttpClient httpClient = new HttpClient();

        /// <summary>
        /// Extension name supplied to the constructor and sent with the registration request.
        /// </summary>
        private readonly string extensionName;

        /// <summary>
        /// Extension registration URL.
        /// </summary>
        private readonly Uri registerUrl;

        /// <summary>
        /// Next event long poll URL.
        /// </summary>
        private readonly Uri nextUrl;

        /// <summary>
        /// Extension initialization error reporting URL.
        /// </summary>
        private readonly Uri initErrorUrl;

        /// <summary>
        /// Extension shutdown error reporting URL.
        /// </summary>
        private readonly Uri shutdownErrorUrl;

        /// <summary>
        /// Creates an Extension API client and computes all endpoint URLs from the
        /// <c>AWS_LAMBDA_RUNTIME_API</c> environment variable.
        /// </summary>
        /// <param name="extensionName">Name sent in the <c>Lambda-Extension-Name</c> header when registering.</param>
        /// <exception cref="ArgumentNullException"><paramref name="extensionName"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The <c>AWS_LAMBDA_RUNTIME_API</c> environment variable is not set or empty.</exception>
        public ExtensionClient(string extensionName)
        {
            this.extensionName = extensionName ?? throw new ArgumentNullException(nameof(extensionName), "Extension name cannot be null");

            // Set infinite timeout so that underlying connection is kept alive;
            // the /event/next call is a long poll that may block for an arbitrary time.
            this.httpClient.Timeout = Timeout.InfiniteTimeSpan;

            // Get Extension API service base address from the environment variable.
            // Fail with an explicit message instead of letting UriBuilder throw an unrelated ArgumentNullException.
            var apiAddress = Environment.GetEnvironmentVariable(LambdaRuntimeApiAddress);
            if (string.IsNullOrEmpty(apiAddress))
            {
                throw new InvalidOperationException($"Environment variable {LambdaRuntimeApiAddress} is not set.");
            }
            var apiUri = new UriBuilder(apiAddress).Uri;

            // Common path for all Extension API URLs
            var basePath = "2020-01-01/extension";

            // Calculate all Extension API endpoints' URLs
            this.registerUrl = new Uri(apiUri, $"{basePath}/register");
            this.nextUrl = new Uri(apiUri, $"{basePath}/event/next");
            this.initErrorUrl = new Uri(apiUri, $"{basePath}/init/error");
            this.shutdownErrorUrl = new Uri(apiUri, $"{basePath}/exit/error");
        }

        #endregion

        #region Public interface

        /// <summary>
        /// Registers the extension and runs the event loop until a SHUTDOWN event is received.
        /// </summary>
        /// <param name="onInit">
        /// Optional function invoked once the extension has been successfully registered with the AWS Lambda Extension API.
        /// It receives the extension id. Called exactly once if it is defined; ignored if null.
        /// </param>
        /// <param name="onInvoke">
        /// Optional function invoked every time the AWS Lambda Extension API reports a new INVOKE event.
        /// It receives the raw JSON body of the event. Ignored if null.
        /// </param>
        /// <param name="onShutdown">
        /// Optional function invoked when the extension receives the SHUTDOWN event from the AWS Lambda Extension API.
        /// It receives the extension id. Called exactly once if it is defined; ignored if null.
        /// </param>
        /// <returns>Awaitable void.</returns>
        /// <remarks>
        /// Unhandled exceptions thrown by <paramref name="onInit"/> and <paramref name="onShutdown"/> are reported to the
        /// AWS Lambda API with /init/error and /exit/error calls; in either case this method returns right after reporting the error.
        /// Unhandled exceptions thrown by <paramref name="onInvoke"/> are logged to console and ignored, so that extension execution can continue.
        /// </remarks>
        /// <exception cref="ApplicationException">The API reported an event type other than INVOKE or SHUTDOWN.</exception>
        public async Task ProcessEvents(Func<string, Task> onInit = null, Func<string, Task> onInvoke = null, Func<string, Task> onShutdown = null)
        {
            // Register extension with AWS Lambda Extension API to handle both INVOKE and SHUTDOWN events
            await RegisterExtensionAsync(ExtensionEvent.INVOKE, ExtensionEvent.SHUTDOWN);

            // If onInit function is defined, invoke it and report any unhandled exceptions
            if (!await SafeInvoke(onInit, this.Id, ex => ReportErrorAsync(this.initErrorUrl, "Fatal.Unhandled", ex)))
            {
                return;
            }

            // loop till SHUTDOWN event is received
            var hasNext = true;
            while (hasNext)
            {
                // get the next event type and details
                var (type, payload) = await GetNextAsync();
                switch (type)
                {
                    case ExtensionEvent.INVOKE:
                        // invoke onInvoke function if one is defined and log unhandled exceptions;
                        // event loop will continue even if there was an exception
                        await SafeInvoke(onInvoke, payload, onException: ex =>
                        {
                            Console.WriteLine($"[{this.extensionName}] Invoke handler threw an exception: {ex}");
                            return Task.CompletedTask;
                        });
                        break;

                    case ExtensionEvent.SHUTDOWN:
                        // terminate the loop, invoke onShutdown function if there is any and report any unhandled exceptions to AWS Extension API
                        hasNext = false;
                        await SafeInvoke(onShutdown, this.Id, ex => ReportErrorAsync(this.shutdownErrorUrl, "Fatal.Unhandled", ex));
                        break;

                    default:
                        throw new ApplicationException($"Unexpected event type: {type}");
                }
            }
        }

        #endregion

        #region Private methods

        /// <summary>
        /// Register extension with Extension API.
        /// </summary>
        /// <param name="events">Event types to be notified about.</param>
        /// <returns>Awaitable void.</returns>
        /// <remarks>
        /// This method is expected to be called just once when extension is being registered with the Extension API.
        /// On success it stores the returned id in <see cref="Id"/> and adds it to the default headers of every subsequent request.
        /// </remarks>
        /// <exception cref="HttpRequestException">The registration call returned a non-success status code.</exception>
        /// <exception cref="ApplicationException">The registration reply did not carry a non-empty id header.</exception>
        private async Task RegisterExtensionAsync(params ExtensionEvent[] events)
        {
            // custom options for JsonSerializer to serialize ExtensionEvent enum values as strings, rather than integers,
            // so the code stays strongly typed instead of relying on string literals
            var options = new JsonSerializerOptions();
            options.Converters.Add(new JsonStringEnumConverter());

            // create Json content for this extension registration: { "events": [ ... ] }
            using var content = new StringContent(JsonSerializer.Serialize(new
            {
                events
            }, options), Encoding.UTF8, "application/json");

            // add extension name header value
            content.Headers.Add(LambdaExtensionNameHeader, this.extensionName);

            // POST call to Extension API
            using var response = await this.httpClient.PostAsync(this.registerUrl, content);

            // if POST call didn't succeed
            if (!response.IsSuccessStatusCode)
            {
                // log details
                Console.WriteLine($"[{this.extensionName}] Error response received for registration request: {await response.Content.ReadAsStringAsync()}");
                // throw an unhandled exception, so that extension is terminated by Lambda runtime
                response.EnsureSuccessStatusCode();
            }

            // get registration id from the response header; GetValues would throw InvalidOperationException
            // on a missing header, so TryGetValues is used and a missing header is treated like an empty one
            this.Id = response.Headers.TryGetValues(LambdaExtensionIdHeader, out var ids) ? ids.FirstOrDefault() : null;
            if (string.IsNullOrEmpty(this.Id))
            {
                throw new ApplicationException("Extension API register call didn't return a valid identifier.");
            }

            // configure HttpClient to send registration id header along with all subsequent requests
            this.httpClient.DefaultRequestHeaders.Add(LambdaExtensionIdHeader, this.Id);
        }

        /// <summary>
        /// Long poll for the next event from Extension API.
        /// </summary>
        /// <returns>Awaitable tuple with the parsed event type and the raw JSON event body.</returns>
        /// <remarks>
        /// It is important to have httpClient.Timeout set to a value longer than any expected wait time
        /// (the constructor uses an infinite timeout), otherwise HttpClient will throw while waiting for the next event.
        /// </remarks>
        private async Task<(ExtensionEvent type, string payload)> GetNextAsync()
        {
            // use GET request to long poll for the next event
            var contentBody = await this.httpClient.GetStringAsync(this.nextUrl);

            // use JsonDocument instead of JsonSerializer, since only the eventType field is needed
            using var doc = JsonDocument.Parse(contentBody);

            // convert eventType to the ExtensionEvent enum and reply with the typed event type and the full event body
            var eventType = Enum.Parse<ExtensionEvent>(doc.RootElement.GetProperty("eventType").GetString());
            return (eventType, contentBody);
        }

        /// <summary>
        /// Report initialization or shutdown error.
        /// </summary>
        /// <param name="url"><see cref="initErrorUrl"/> or <see cref="shutdownErrorUrl"/>.</param>
        /// <param name="errorType">Error type string, e.g. Fatal.ConnectionError or any other meaningful type.</param>
        /// <param name="exception">Exception details.</param>
        /// <returns>Awaitable void.</returns>
        /// <remarks>
        /// This implementation appends the <paramref name="exception"/> type name to <paramref name="errorType"/> for demonstration purposes.
        /// The id header is not added here: it is sent via httpClient.DefaultRequestHeaders, which RegisterExtensionAsync
        /// populates before any error can be reported; adding it to the content as well would send the header twice.
        /// </remarks>
        /// <exception cref="HttpRequestException">The error reporting call returned a non-success status code.</exception>
        private async Task ReportErrorAsync(Uri url, string errorType, Exception exception)
        {
            using var content = new StringContent(string.Empty);
            content.Headers.Add(LambdaExtensionFunctionErrorTypeHeader, $"{errorType}.{exception.GetType().Name}");

            using var response = await this.httpClient.PostAsync(url, content);
            if (!response.IsSuccessStatusCode)
            {
                Console.WriteLine($"[{this.extensionName}] Error response received for {url.PathAndQuery}: {await response.Content.ReadAsStringAsync()}");
                response.EnsureSuccessStatusCode();
            }
        }

        /// <summary>
        /// Try to invoke <paramref name="func"/> and call <paramref name="onException"/> if it threw an exception.
        /// </summary>
        /// <param name="func">Function to be invoked. Do nothing if it is null.</param>
        /// <param name="param">Parameter to pass to <paramref name="func"/>.</param>
        /// <param name="onException">Exception reporting function to be called in case of an exception. Can be null.</param>
        /// <returns>Awaitable boolean value: true if succeeded (or <paramref name="func"/> is null), false otherwise.</returns>
        /// <remarks>Exceptions thrown by <paramref name="onException"/> itself propagate to the caller.</remarks>
        private async Task<bool> SafeInvoke(Func<string, Task> func, string param, Func<Exception, Task> onException)
        {
            try
            {
                // Explicit null checks are required: `await func?.Invoke(param)` awaits a null Task
                // when func is null, which throws NullReferenceException.
                if (func != null)
                {
                    await func(param);
                }
                return true;
            }
            catch (Exception ex)
            {
                if (onException != null)
                {
                    await onException(ex);
                }
                return false;
            }
        }

        #endregion

        #region IDisposable implementation

        /// <summary>
        /// Dispose of instance Disposable variables.
        /// </summary>
        public void Dispose()
        {
            // Propagate Dispose call to the HttpClient instance
            this.httpClient.Dispose();
        }

        #endregion
    }
}