#!/usr/bin/python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "LICENSE.txt" file accompanying this file. # This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. # See the License for the specific language governing permissions and limitations under the License. import json import logging import os import boto3 from s3_factory import S3DocumentManager PARTITION_TO_MAIN_REGION = {"commercial": "us-east-1", "govcloud": "us-gov-west-1", "china": "cn-north-1"} PARTITIONS = ["commercial", "china", "govcloud"] FILE_TO_S3_PATH = {"instances": "instances/instances.json", "feature_whitelist": "features/feature_whitelist.json"} def get_aws_regions(partition): ec2 = boto3.client("ec2", region_name=PARTITION_TO_MAIN_REGION[partition]) return set(r.get("RegionName") for r in ec2.describe_regions().get("Regions")) def retrieve_sts_credentials(credentials, client_region, regions): """ Given credentials from cli, returns a json credentials object. { 'us-east-1': { 'aws_access_key_id': 'sjkdnf', 'aws_secret_access_key': 'ksjdfkjsd', 'aws_session_token': 'skajdfksdjn' } ... } :param credentials: STS credential endpoint, in the format ,,,. Could be specified multiple times :param client_region: region of the client that is assuming the role :return: sts credentials json """ sts_credentials = {} for credential in credentials: region, endpoint, arn, external_id = credential sts = boto3.client("sts", region_name=client_region, endpoint_url=endpoint) assumed_role_object = sts.assume_role( RoleArn=arn, ExternalId=external_id, RoleSessionName=region + "-upload_instance_slot_map_sts_session" ) sts_credentials[region] = { "aws_access_key_id": assumed_role_object["Credentials"].get("AccessKeyId"), "aws_secret_access_key": assumed_role_object["Credentials"].get("SecretAccessKey"), "aws_session_token": assumed_role_object["Credentials"].get("SessionToken"), } if sts_credentials.get("default"): for region in regions: if region not in sts_credentials: sts_credentials[region] = sts_credentials["default"] return sts_credentials def generate_rollback_data(regions, dest_bucket, files, sts_credentials): rollback_data = {} for region in regions: bucket_name = dest_bucket.format(region=region) rollback_data[bucket_name] = {"region": region, "files": {}} doc_manager = S3DocumentManager(region, sts_credentials.get(region)) for file_type in files: s3_path = FILE_TO_S3_PATH.get(file_type, file_type) version = doc_manager.get_current_version( dest_bucket.format(region=region), s3_path, raise_on_object_not_found=False, ) rollback_data[bucket_name]["files"][s3_path] = version logging.info("Rollback data:\n%s", json.dumps(rollback_data, indent=2)) rollback_file_name = "rollback-data.json" with open(rollback_file_name, "w", encoding="utf-8") as outfile: json.dump(rollback_data, outfile, indent=2) logging.info("Rollback data file created to: %s", f"{os.getcwd()}/{rollback_file_name}") return rollback_data