""" Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: MIT-0 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import boto3 import json import os from jose import jwk, jwt from jose.utils import base64url_decode import time import urllib.request import re region = os.environ['REGION'] user_pool_id = os.environ['USERPOOLID'] app_client_id = os.environ['APPCLIENTID'] keys_url = 'https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json'.format(region, user_pool_id) with urllib.request.urlopen(keys_url) as f: response = f.read() keys = json.loads(response.decode('utf-8'))['keys'] def lambda_handler(event, context): """Do not print the auth token unless absolutely necessary """ #print("Client token: " + event['authorizationToken']) token = event["authorizationToken"] tmp = event['methodArn'].split(':') apiGatewayArnTmp = tmp[5].split('/') awsAccountId = tmp[4] jwt_decode = JWT(token) decode = jwt_decode.decode() if not decode == False: principalId = jwt_decode.get_sub() policy = AuthPolicy(principalId, awsAccountId) policy.restApiId = apiGatewayArnTmp[0] policy.region = tmp[3] policy.stage = apiGatewayArnTmp[1] policy.allowAllMethods() context = { 'clientId': jwt_decode.get_client_id() } else: principalId = '' policy = AuthPolicy(principalId, awsAccountId) policy.restApiId = apiGatewayArnTmp[0] policy.region = tmp[3] policy.stage = apiGatewayArnTmp[1] policy.denyAllMethods() context = {} # Finally, build the policy authResponse = policy.build() authResponse['context'] = context return authResponse class JWT(object): token = "" claims = {} cognito = boto3.client('cognito-idp') user_details = {} def __init__(self, token): self.token = token def decode(self): # get the kid from the headers prior to verification headers = jwt.get_unverified_headers(self.token) kid = headers['kid'] # search for the kid in the downloaded public keys key_index = -1 for i in range(len(keys)): if kid == keys[i]['kid']: key_index = i break if key_index == -1: print('Public key not found in jwks.json') return False # construct the public key public_key = jwk.construct(keys[key_index]) # get the last two sections of the token, # message and signature (encoded in base64) message, encoded_signature = str(self.token).rsplit('.', 1) # decode the signature decoded_signature = base64url_decode(encoded_signature.encode('utf-8')) # verify the signature if not public_key.verify(message.encode("utf8"), decoded_signature): print('Signature verification failed') return False print('Signature successfully verified') # since we passed the verification, we can now safely # use the unverified claims self.claims = jwt.get_unverified_claims(self.token) # additionally we can verify the token expiration if time.time() > self.claims['exp']: print('Token is expired') self.claims = {} return False # and the Audience (use claims['client_id'] if verifying an access token) if 'aud' in self.claims and self.claims['aud'] != app_client_id: print('Token was not issued for this audience') self.claims = {} return False # now we can use the claims user_detail = self.cognito.get_user(AccessToken=self.token) user_dict = {} if 'UserAttributes' in user_detail: for attribute in user_detail['UserAttributes']: user_dict[attribute['Name']] = (attribute['Value']) self.user_details = user_dict return self.claims def get_sub(self): return self.claims['sub'] def get_client_id(self): if 'custom:clientId' in self.user_details: return self.user_details['custom:clientId'] else: print('Custom client id not found for user') return '' class HttpVerb: GET = "GET" POST = "POST" PUT = "PUT" PATCH = "PATCH" HEAD = "HEAD" DELETE = "DELETE" OPTIONS = "OPTIONS" ALL = "*" class AuthPolicy(object): awsAccountId = "" principalId = "" version = "2012-10-17" pathRegex = "^[/.a-zA-Z0-9-\*]+$" allowMethods = [] denyMethods = [] def __init__(self, principal, awsAccountId): self.awsAccountId = awsAccountId self.principalId = principal self.allowMethods = [] self.denyMethods = [] def _addMethod(self, effect, verb, resource, conditions): if verb != "*" and not hasattr(HttpVerb, verb): raise NameError("Invalid HTTP verb " + verb + ". Allowed verbs in HttpVerb class") resourcePattern = re.compile(self.pathRegex) if not resourcePattern.match(resource): raise NameError("Invalid resource path: " + resource + ". Path should match " + self.pathRegex) if resource[:1] == "/": resource = resource[1:] resourceArn = ("arn:aws:execute-api:" + self.region + ":" + self.awsAccountId + ":" + self.restApiId + "/" + self.stage + "/" + verb + "/" + resource) if effect.lower() == "allow": self.allowMethods.append({ 'resourceArn' : resourceArn, 'conditions' : conditions }) elif effect.lower() == "deny": self.denyMethods.append({ 'resourceArn' : resourceArn, 'conditions' : conditions }) def _getEmptyStatement(self, effect): statement = { 'Action': 'execute-api:Invoke', 'Effect': effect[:1].upper() + effect[1:].lower(), 'Resource': [] } return statement def _getStatementForEffect(self, effect, methods): statements = [] if len(methods) > 0: statement = self._getEmptyStatement(effect) for curMethod in methods: if curMethod['conditions'] is None or len(curMethod['conditions']) == 0: statement['Resource'].append(curMethod['resourceArn']) else: conditionalStatement = self._getEmptyStatement(effect) conditionalStatement['Resource'].append(curMethod['resourceArn']) conditionalStatement['Condition'] = curMethod['conditions'] statements.append(conditionalStatement) statements.append(statement) return statements def allowAllMethods(self): self._addMethod("Allow", HttpVerb.ALL, "*", []) def denyAllMethods(self): self._addMethod("Deny", HttpVerb.ALL, "*", []) def allowMethod(self, verb, resource): self._addMethod("Allow", verb, resource, []) def denyMethod(self, verb, resource): self._addMethod("Deny", verb, resource, []) def allowMethodWithConditions(self, verb, resource, conditions): self._addMethod("Allow", verb, resource, conditions) def denyMethodWithConditions(self, verb, resource, conditions): self._addMethod("Deny", verb, resource, conditions) def build(self): if ((self.allowMethods is None or len(self.allowMethods) == 0) and (self.denyMethods is None or len(self.denyMethods) == 0)): raise NameError("No statements defined for the policy") policy = { 'principalId' : self.principalId, 'policyDocument' : { 'Version' : self.version, 'Statement' : [] } } policy['policyDocument']['Statement'].extend(self._getStatementForEffect("Allow", self.allowMethods)) policy['policyDocument']['Statement'].extend(self._getStatementForEffect("Deny", self.denyMethods)) return policy