# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import logging

from awscli.customizations.commands import BasicCommand
from awscli.customizations.emrcontainers.constants \
    import TRUST_POLICY_STATEMENT_FORMAT, \
    TRUST_POLICY_STATEMENT_ALREADY_EXISTS, \
    TRUST_POLICY_UPDATE_SUCCESSFUL
from awscli.customizations.emrcontainers.base36 import Base36
from awscli.customizations.emrcontainers.eks import EKS
from awscli.customizations.emrcontainers.iam import IAM
from awscli.customizations.utils import uni_print, get_policy_arn_suffix

LOG = logging.getLogger(__name__)

# IAM policy language version written into a trust policy document that has
# to be created from scratch (only happens when the role returned none).
_DEFAULT_POLICY_VERSION = "2012-10-17"


def get_region(session, parsed_globals):
    """Return the region to use for the EKS and IAM clients.

    The ``--region`` global argument wins; when it was not given, the
    session's ``region`` config variable is used instead.

    :param session: Session used to look up the ``region`` config variable.
    :param parsed_globals: Parsed global CLI arguments (reads ``.region``).
    :return: The region name taken from one of the two sources above.
    """
    region = parsed_globals.region
    if region is None:
        region = session.get_config_variable('region')
    return region


def _get_statement_list(policy_document):
    """Return the ``Statement`` entries of *policy_document* as a list.

    The IAM policy grammar allows ``Statement`` to be either a single
    statement object or an array of statements. Both forms, and a missing
    key, are normalized to a list so callers can iterate and append safely.
    When the document already holds a list, that same list object is
    returned.

    :param policy_document: A parsed policy document (dict).
    :return: A list of statements (possibly empty).
    """
    statements = policy_document.get("Statement")
    if statements is None:
        return []
    if isinstance(statements, dict):
        # Single-statement form: wrap it so it can be iterated/appended to.
        return [statements]
    return statements


def check_if_statement_exists(expected_statement, actual_assume_role_document):
    """Return True if the role's trust policy already has *expected_statement*.

    :param expected_statement: The statement (dict) that must be present.
    :param actual_assume_role_document: The role's current trust policy
        document (dict), or ``None``.
    :return: True if any existing statement matches per
        :func:`check_if_dict_matches`; False otherwise, including when the
        document is ``None``.
    """
    if actual_assume_role_document is None:
        return False

    return any(
        check_if_dict_matches(expected_statement, existing_statement)
        for existing_statement in
        _get_statement_list(actual_assume_role_document)
    )


def check_if_dict_matches(expected_dict, actual_dict):
    """Recursively check that *actual_dict* matches *expected_dict*.

    Both dicts must have the same number of keys. Every expected key must be
    present in *actual_dict*. Nested dicts are compared recursively. Any
    other expected value is converted with ``str()`` and compared for
    equality with the actual value.

    :param expected_dict: The reference mapping.
    :param actual_dict: The mapping to check; anything that is not a dict
        never matches.
    :return: True if the mappings match, False otherwise.
    """
    if not isinstance(actual_dict, dict):
        # Guards against malformed policies (e.g. a string where an object
        # was expected), which previously raised AttributeError.
        return False
    if len(expected_dict) != len(actual_dict):
        return False

    for key, val in expected_dict.items():
        key_str = str(key)
        if isinstance(val, dict):
            if not check_if_dict_matches(val, actual_dict.get(key_str)):
                return False
        elif key_str not in actual_dict or actual_dict[key_str] != str(val):
            return False
    return True


class UpdateRoleTrustPolicyCommand(BasicCommand):
    """``aws emr-containers update-role-trust-policy`` command.

    Ensures an IAM role's trust policy contains the statement built from
    ``TRUST_POLICY_STATEMENT_FORMAT`` for the given EKS cluster, namespace
    and role. The statement is appended if it is missing; with ``--dry-run``
    the merged policy is printed instead of being written back to IAM.
    """

    NAME = 'update-role-trust-policy'

    DESCRIPTION = BasicCommand.FROM_FILE(
        'emr-containers',
        'update-role-trust-policy',
        '_description.rst'
    )

    ARG_TABLE = [
        {
            'name': 'cluster-name',
            'help_text': ("Specify the name of the Amazon EKS cluster with "
                          "which the IAM Role would be used."),
            'required': True
        },
        {
            'name': 'namespace',
            'help_text': ("Specify the namespace from the Amazon EKS cluster "
                          "with which the IAM Role would be used."),
            'required': True
        },
        {
            'name': 'role-name',
            'help_text': ("Specify the IAM Role name that you want to use "
                          "with Amazon EMR on EKS."),
            'required': True
        },
        {
            'name': 'iam-endpoint',
            'no_paramfile': True,
            'help_text': ("The IAM endpoint to call for updating the role "
                          "trust policy. This is optional and should only be "
                          "specified when a custom endpoint should be called "
                          "for IAM operations."),
            'required': False
        },
        {
            'name': 'dry-run',
            'action': 'store_true',
            'default': False,
            'help_text': ("Print the merged trust policy document to "
                          "stdout instead of updating the role trust "
                          "policy directly."),
            'required': False
        }
    ]

    def _run_main(self, parsed_args, parsed_globals):
        """Store the parsed arguments, run the update and print its result.

        :return: 0 once the result message has been printed.
        """
        self._cluster_name = parsed_args.cluster_name
        self._namespace = parsed_args.namespace
        self._role_name = parsed_args.role_name
        self._region = get_region(self._session, parsed_globals)
        self._endpoint_url = parsed_args.iam_endpoint
        self._dry_run = parsed_args.dry_run

        result = self._update_role_trust_policy(parsed_globals)
        uni_print(result)
        uni_print("\n")
        return 0

    def _update_role_trust_policy(self, parsed_globals):
        """Add the required trust statement to the role if it is missing.

        :param parsed_globals: Parsed global CLI arguments (reads
            ``verify_ssl`` for the EKS and IAM clients).
        :return: The "already exists" message when the statement is present.
            Otherwise, the merged trust policy as indented JSON when
            ``--dry-run`` was given, or the "update successful" message after
            the policy has been written back via IAM.
        """
        base36 = Base36()

        eks_client = EKS(self._session.create_client(
            'eks',
            region_name=self._region,
            verify=parsed_globals.verify_ssl
        ))

        account_id = eks_client.get_account_id(self._cluster_name)
        oidc_provider = eks_client.get_oidc_issuer_id(self._cluster_name)

        base36_encoded_role_name = base36.encode(self._role_name)
        LOG.debug('Base36 encoded role name: %s', base36_encoded_role_name)

        trust_policy_statement = json.loads(TRUST_POLICY_STATEMENT_FORMAT % {
            "AWS_ACCOUNT_ID": account_id,
            "OIDC_PROVIDER": oidc_provider,
            "NAMESPACE": self._namespace,
            "BASE36_ENCODED_ROLE_NAME": base36_encoded_role_name,
            "AWS_PARTITION": get_policy_arn_suffix(self._region)
        })

        LOG.debug('Computed Trust Policy Statement:\n%s', json.dumps(
            trust_policy_statement, indent=2))

        iam_client = IAM(self._session.create_client(
            'iam',
            region_name=self._region,
            endpoint_url=self._endpoint_url,
            verify=parsed_globals.verify_ssl
        ))

        assume_role_document = iam_client.get_assume_role_policy(
            self._role_name)

        if check_if_statement_exists(trust_policy_statement,
                                     assume_role_document):
            return TRUST_POLICY_STATEMENT_ALREADY_EXISTS % self._role_name

        LOG.debug('Role %s does not have the required trust policy ',
                  self._role_name)

        if assume_role_document is None:
            # check_if_statement_exists accepts a missing document, so handle
            # it here as well (previously an AttributeError) by starting a
            # fresh policy that will hold only the new statement.
            assume_role_document = {"Version": _DEFAULT_POLICY_VERSION}

        # Normalize single-object / missing "Statement" to a list before
        # appending; appending to a dict-valued Statement used to crash.
        statements = _get_statement_list(assume_role_document)
        statements.append(trust_policy_statement)
        assume_role_document["Statement"] = statements

        if self._dry_run:
            return json.dumps(assume_role_document, indent=2)

        LOG.debug('Updating trust policy of role %s', self._role_name)
        iam_client.update_assume_role_policy(self._role_name,
                                             assume_role_document)
        return TRUST_POLICY_UPDATE_SUCCESSFUL % self._role_name