#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
#
import json
from operator import indexOf
import time
import os
import boto3
import urllib.request
from jose import jwk, jwt
from jose.exceptions import JOSEError
from jose.utils import base64url_decode
from boto3.dynamodb.conditions import Key
import re

# DynamoDB table that maps tenants to their Cognito user pool and app client.
AUTH_INFO_TABLE_NAME = 'AuthInfo'
# NOTE(review): the table region is pinned to us-west-2 while the Cognito
# issuer/JWKS URL uses the Lambda's own AWS_REGION; confirm this is intended.
AUTH_INFO_TABLE_REGION = 'us-west-2'
# Seconds to wait for the Cognito JWKS endpoint before giving up.
JWKS_TIMEOUT_SECONDS = 5
# Optional prefix on the incoming authorization token.
BEARER_PREFIX = 'Bearer '

# Lazily created DynamoDB Table; reused across warm invocations of the container.
_auth_info_table = None
# keys_url -> list of JWK dicts last downloaded from that URL.
_JWKS_CACHE = {}


def lambda_handler(event, context):
    """API Gateway TOKEN authorizer backed by per-tenant Cognito user pools.

    Resolves the tenant's user pool from the (not yet verified)
    ``custom:tenant-id`` claim, verifies the token's signature, issuer, expiry
    and audience against that pool, and returns a policy allowing every method
    of the calling API stage.

    Args:
        event: authorizer event; reads ``authorizationToken`` and ``methodArn``.
        context: Lambda context (unused).

    Returns:
        The dict produced by ``AuthPolicy.build()``.

    Raises:
        Exception('Unauthorized'): for a missing/malformed token or any failed
            lookup or verification step.
    """
    region = os.environ['AWS_REGION']
    token = _extract_token(event)

    # The tenant id is read before verification only to choose which user
    # pool's keys to verify against; nothing is trusted until _verify_token.
    try:
        payload = jwt.get_unverified_claims(token)
    except JOSEError:
        print('Token could not be decoded')
        raise Exception('Unauthorized') from None
    tenant_id = payload.get('custom:tenant-id')
    if tenant_id is None:
        print('Token has no custom:tenant-id claim')
        raise Exception('Unauthorized')

    userpool_id, app_client_id = _lookup_user_pool(tenant_id)
    claims = _verify_token(token, region, userpool_id, app_client_id)
    # Claims are intentionally not logged: they contain user PII.

    principal_id = claims.get('cognito:username')
    if principal_id is None:
        print('Token has no cognito:username claim')
        raise Exception('Unauthorized')

    # methodArn is split as arn:aws:execute-api:{region}:{account}:{apiId}/{stage}/...
    arn_parts = event['methodArn'].split(':')
    api_gateway_arn_parts = arn_parts[5].split('/')
    policy = AuthPolicy(principal_id, arn_parts[4])
    policy.restApiId = api_gateway_arn_parts[0]
    policy.region = arn_parts[3]
    policy.stage = api_gateway_arn_parts[1]
    policy.allowAllMethods()
    return policy.build()


def _extract_token(event):
    """Return the JWT from *event*, with a leading 'Bearer ' prefix removed."""
    token = event.get('authorizationToken')
    if not token:
        print('No authorization token supplied')
        raise Exception('Unauthorized')
    if token.startswith(BEARER_PREFIX):
        token = token[len(BEARER_PREFIX):]
    return token


def _get_auth_info_table():
    """Return the AuthInfo DynamoDB Table, creating it on first use."""
    global _auth_info_table
    if _auth_info_table is None:
        dynamodb = boto3.resource('dynamodb', region_name=AUTH_INFO_TABLE_REGION)
        _auth_info_table = dynamodb.Table(AUTH_INFO_TABLE_NAME)
    return _auth_info_table


def _lookup_user_pool(tenant_id):
    """Return ``(user_pool_id, client_id)`` for *tenant_id*.

    Unknown tenants fall back to the item whose ``tenant_path`` is 'app'.
    Raises Exception('Unauthorized') if neither lookup finds an item.
    """
    table = _get_auth_info_table()
    result = table.query(
        IndexName='tenantId-Index',
        KeyConditionExpression=Key('tenantId').eq(tenant_id)
    )
    if result['Count'] == 0:
        result = table.query(
            KeyConditionExpression=Key('tenant_path').eq('app')
        )
        if result['Count'] == 0:
            raise Exception('Unauthorized')
    item = result['Items'][0]
    return item['user_pool_id'], item['client_id']


def _fetch_jwks(keys_url):
    """Download and return the 'keys' list from a JWKS URL."""
    with urllib.request.urlopen(keys_url, timeout=JWKS_TIMEOUT_SECONDS) as f:
        response = f.read()
    return json.loads(response.decode('utf-8'))['keys']


def _find_key(keys, kid):
    """Return the JWK in *keys* whose 'kid' equals *kid*, or None."""
    return next((key for key in keys if key.get('kid') == kid), None)


def _get_public_key(keys_url, kid):
    """Return the JWK for *kid*, using the cache and refetching once on a miss."""
    cached = _JWKS_CACHE.get(keys_url)
    if cached is not None:
        key = _find_key(cached, kid)
        if key is not None:
            return key
    # Cache miss or unknown kid (e.g. keys rotated): refresh from the pool.
    keys = _fetch_jwks(keys_url)
    _JWKS_CACHE[keys_url] = keys
    return _find_key(keys, kid)


def _verify_token(token, region, userpool_id, app_client_id):
    """Verify *token* against a Cognito user pool and return its claims.

    Checks, in order: a matching signing key in the pool's JWKS, the
    signature, the issuer, the expiry and the audience (app client id).
    Any failure raises Exception('Unauthorized').
    """
    issuer = 'https://cognito-idp.{}.amazonaws.com/{}'.format(region, userpool_id)
    keys_url = issuer + '/.well-known/jwks.json'

    # Get the kid from the headers prior to verification.
    try:
        headers = jwt.get_unverified_headers(token)
    except JOSEError:
        print('Token headers could not be decoded')
        raise Exception('Unauthorized') from None
    kid = headers.get('kid')
    key_data = _get_public_key(keys_url, kid) if kid else None
    if key_data is None:
        print('Public key not found in jwks.json')
        raise Exception('Unauthorized')
    public_key = jwk.construct(key_data)

    # The signed message is everything before the last '.'; the last
    # segment is the base64url-encoded signature.
    message, encoded_signature = str(token).rsplit('.', 1)
    decoded_signature = base64url_decode(encoded_signature.encode('utf-8'))
    if not public_key.verify(message.encode("utf8"), decoded_signature):
        print('Signature verification failed')
        raise Exception('Unauthorized')
    print('Signature successfully verified')

    # The signature is valid, so the claims can now safely be used.
    claims = jwt.get_unverified_claims(token)

    if claims.get('iss') != issuer:
        print('Token was not issued by the expected user pool')
        raise Exception('Unauthorized')

    exp = claims.get('exp')
    if exp is None or time.time() > exp:
        print('Token is expired')
        raise Exception('Unauthorized')

    # Audience check (use claims['client_id'] if verifying an access token).
    if claims.get('aud') != app_client_id:
        print('Token was not issued for this audience')
        raise Exception('Unauthorized')

    return claims


class HttpVerb:
    """HTTP verbs accepted by AuthPolicy; ALL ('*') matches any verb."""
    GET = 'GET'
    POST = 'POST'
    PUT = 'PUT'
    PATCH = 'PATCH'
    HEAD = 'HEAD'
    DELETE = 'DELETE'
    OPTIONS = 'OPTIONS'
    ALL = '*'


# Verb values AuthPolicy accepts. Validation is by value so that attribute
# names such as 'ALL' or dunder names cannot pass as verbs.
_HTTP_VERBS = frozenset((
    HttpVerb.GET, HttpVerb.POST, HttpVerb.PUT, HttpVerb.PATCH,
    HttpVerb.HEAD, HttpVerb.DELETE, HttpVerb.OPTIONS, HttpVerb.ALL,
))


class AuthPolicy(object):
    """Builder for the policy document returned by an API Gateway authorizer.

    Add allowed/denied methods with the allow*/deny* helpers, then call
    build() to get ``{'principalId': ..., 'policyDocument': ...}``.
    """

    # The AWS account id the policy is generated for; used to build method ARNs.
    awsAccountId = ''
    # The principal for the policy; should uniquely identify the end user.
    principalId = ''
    # The policy version used for evaluation; should always be '2012-10-17'.
    version = '2012-10-17'
    # Regular expression used to validate resource paths.
    pathRegex = r'^[/.a-zA-Z0-9-\*]+$'

    # Internal lists of allowed and denied methods. Each entry is a dict with a
    # 'resourceArn' and a nullable 'conditions' statement; build() turns them
    # into the appropriate policy statements. These class-level lists are only
    # defaults: __init__ gives every instance its own lists.
    allowMethods = []
    denyMethods = []

    # Default API Gateway API id, region and stage used in generated ARNs;
    # replace the placeholders. Beware of using '*': it will not simply mean
    # "any" API id / region / stage, because stars greedily expand over '/'
    # and other separators. See
    # https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_resource.html
    restApiId = "<<restApiId>>"
    region = "<<region>>"
    stage = "<<stage>>"

    def __init__(self, principal, awsAccountId):
        self.awsAccountId = awsAccountId
        self.principalId = principal
        self.allowMethods = []
        self.denyMethods = []

    def _addMethod(self, effect, verb, resource, conditions):
        """Add a method to the internal allow or deny list.

        Args:
            effect: 'Allow' or 'Deny' (case-insensitive); other values are ignored.
            verb: an HttpVerb value (including '*').
            resource: resource path matching pathRegex; a leading '/' is dropped.
            conditions: IAM condition block for the statement; may be None/empty.

        Raises:
            NameError: if the verb or the resource path is invalid.
        """
        if verb not in _HTTP_VERBS:
            raise NameError('Invalid HTTP verb ' + verb + '. Allowed verbs in HttpVerb class')
        if not re.match(self.pathRegex, resource):
            raise NameError('Invalid resource path: ' + resource + '. Path should match ' + self.pathRegex)

        if resource[:1] == '/':
            resource = resource[1:]

        resourceArn = 'arn:aws:execute-api:{}:{}:{}/{}/{}/{}'.format(
            self.region, self.awsAccountId, self.restApiId, self.stage, verb, resource)

        entry = {
            'resourceArn': resourceArn,
            'conditions': conditions
        }
        if effect.lower() == 'allow':
            self.allowMethods.append(entry)
        elif effect.lower() == 'deny':
            self.denyMethods.append(entry)

    def _getEmptyStatement(self, effect):
        """Return an empty statement with the Invoke action and the given effect."""
        statement = {
            'Action': 'execute-api:Invoke',
            'Effect': effect[:1].upper() + effect[1:].lower(),
            'Resource': []
        }
        return statement

    def _getStatementForEffect(self, effect, methods):
        """Return the policy statements for *methods* with the given effect.

        Unconditional methods are grouped into one statement; each method with
        conditions gets its own statement carrying its 'Condition'.
        """
        statements = []
        if not methods:
            return statements

        statement = self._getEmptyStatement(effect)
        for curMethod in methods:
            if not curMethod['conditions']:
                statement['Resource'].append(curMethod['resourceArn'])
            else:
                conditionalStatement = self._getEmptyStatement(effect)
                conditionalStatement['Resource'].append(curMethod['resourceArn'])
                conditionalStatement['Condition'] = curMethod['conditions']
                statements.append(conditionalStatement)

        if statement['Resource']:
            statements.append(statement)
        return statements

    def allowAllMethods(self):
        """Add a '*' allow to the policy, authorizing all methods of the API."""
        self._addMethod('Allow', HttpVerb.ALL, '*', [])

    def denyAllMethods(self):
        """Add a '*' deny to the policy, denying all methods of the API."""
        self._addMethod('Deny', HttpVerb.ALL, '*', [])

    def allowMethod(self, verb, resource):
        """Add an API Gateway method (HTTP verb + resource path) to the allowed list."""
        self._addMethod('Allow', verb, resource, [])

    def denyMethod(self, verb, resource):
        """Add an API Gateway method (HTTP verb + resource path) to the denied list."""
        self._addMethod('Deny', verb, resource, [])

    def allowMethodWithConditions(self, verb, resource, conditions):
        """Add an allowed method whose statement carries *conditions*.

        More on AWS policy conditions:
        http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition
        """
        self._addMethod('Allow', verb, resource, conditions)

    def denyMethodWithConditions(self, verb, resource, conditions):
        """Add a denied method whose statement carries *conditions*.

        More on AWS policy conditions:
        http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition
        """
        self._addMethod('Deny', verb, resource, conditions)

    def build(self):
        """Generate the authorizer response from the allowed and denied methods.

        Produces one Allow and one Deny statement for unconditional methods,
        plus a separate statement for every method that has conditions.

        Raises:
            NameError: if no methods were added.
        """
        if not self.allowMethods and not self.denyMethods:
            raise NameError('No statements defined for the policy')

        policy = {
            'principalId': self.principalId,
            'policyDocument': {
                'Version': self.version,
                'Statement': []
            }
        }

        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Allow', self.allowMethods))
        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Deny', self.denyMethods))

        return policy