using System; using System.IdentityModel.Tokens.Jwt; using System.Net; using System.Security.Claims; using System.Threading.Tasks; using Amazon.Lambda.APIGatewayEvents; using Amazon.Lambda.Core; using Amazon.Lambda.Serialization.Json; using Amazon.Runtime; using Amazon.XRay.Recorder.Handlers.AwsSdk; using Microsoft.Extensions.Configuration; using Microsoft.IdentityModel.Protocols; using Microsoft.IdentityModel.Protocols.OpenIdConnect; using Microsoft.IdentityModel.Tokens; // Assembly attribute to enable the Lambda function's JSON input to be converted into a .NET class. [assembly: LambdaSerializer(typeof(JsonSerializer))] namespace ImageRecognition.Communication.Functions { public class Functions { private const string AUTHORIZATION_HEADER = "Authorization"; private const string BEARER_PREFIX = "Bearer"; private const string TABLE_NAME_ENV = "COMMUNICATION_TABLE"; private TokenValidationParameters _jwtValidationParameters; private readonly CommunicationManager _manager; /// /// Default constructor that Lambda will invoke. /// public Functions() { AWSSDKHandler.RegisterXRayForAllServices(); _manager = CommunicationManager.CreateManager(Environment.GetEnvironmentVariable(TABLE_NAME_ENV)); } /// /// Verify JWT token in Authorization header and if valid allow connection. /// /// /// public async Task OnConnect(APIGatewayProxyRequest request, ILambdaContext context) { if (_jwtValidationParameters == null) _jwtValidationParameters = await CreateTokenValidationParameters(context); try { var username = ValidateAndGetUsername(request, context); var domainName = request.RequestContext.DomainName; var stage = request.RequestContext.Stage; var endpoint = $"https://{domainName}/{stage}"; if (string.IsNullOrEmpty(username)) { context.Logger.LogLine("Error, no username claim found in JWT token"); return new APIGatewayProxyResponse { StatusCode = (int) HttpStatusCode.Unauthorized }; } context.Logger.LogLine( $"Login with connection id: {request.RequestContext.ConnectionId}, Endpoint: {endpoint}, Username: {username}"); await _manager.LoginAsync(request.RequestContext.ConnectionId, endpoint, username); return new APIGatewayProxyResponse { StatusCode = (int) HttpStatusCode.OK }; } catch { return new APIGatewayProxyResponse { StatusCode = (int) HttpStatusCode.Unauthorized }; } } public async Task OnDisconnect(APIGatewayProxyRequest request, ILambdaContext context) { context.Logger.LogLine($"Logoff with connection id: {request.RequestContext.ConnectionId}"); await _manager.LogoffAsync(request.RequestContext.ConnectionId); var response = new APIGatewayProxyResponse { StatusCode = (int) HttpStatusCode.OK }; return response; } public string ValidateAndGetUsername(APIGatewayProxyRequest request, ILambdaContext context) { string authorization; if (!request.QueryStringParameters.TryGetValue(AUTHORIZATION_HEADER, out authorization)) { context.Logger.LogLine("Error, no Authorization header found"); throw new Exception("Error, no Authorization header found"); } if (authorization.StartsWith(BEARER_PREFIX, StringComparison.OrdinalIgnoreCase)) authorization = authorization.Substring(BEARER_PREFIX.Length + 1); ClaimsPrincipal user; try { SecurityToken validatedToken; user = new JwtSecurityTokenHandler().ValidateToken(authorization, _jwtValidationParameters, out validatedToken); if (DateTime.UtcNow < validatedToken.ValidFrom || validatedToken.ValidTo < DateTime.UtcNow) { Console.WriteLine( $"Error, JWT Token expired. Token was valid from {validatedToken.ValidFrom} to {validatedToken.ValidTo}"); throw new Exception("JWT Token expired"); } } catch (Exception e) { Console.WriteLine($"Error validating JWT token: {e.Message}"); throw; } return user.FindFirst("username")?.Value; } private async Task CreateTokenValidationParameters(ILambdaContext context) { context.Logger.LogLine("Loading user pool configuration from SSM Parameter Store."); var configuration = new ConfigurationBuilder() .AddSystemsManager("/ImageRecognition") .Build(); var region = configuration["AWS:Region"]; if (string.IsNullOrEmpty(region)) region = FallbackRegionFactory.GetRegionEndpoint().SystemName; var userPoolId = configuration["AWS:UserPoolId"]; var userPoolClientId = configuration["AWS:UserPoolClientId"]; context.Logger.LogLine("Configuring JWT Validation parameters"); var openIdConfigurationUrl = $"https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/openid-configuration"; var configurationManager = new ConfigurationManager(openIdConfigurationUrl, new OpenIdConnectConfigurationRetriever()); context.Logger.LogLine($"Loading open id configuration from {openIdConfigurationUrl}"); var openIdConfig = await configurationManager.GetConfigurationAsync(); var validIssuer = $"https://cognito-idp.{region}.amazonaws.com/{userPoolId}"; context.Logger.LogLine($"Valid Issuer: {validIssuer}"); context.Logger.LogLine($"Valid Audiences: {userPoolClientId}"); return new TokenValidationParameters { ValidIssuer = validIssuer, ValidateAudience = false, ValidAudiences = new[] {userPoolClientId}, IssuerSigningKeys = openIdConfig.SigningKeys }; } } }