######################################################################################################################
#  Copyright 2020-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.                                      #
#                                                                                                                    #
#  Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance    #
#  with the License. A copy of the License is located at                                                             #
#                                                                                                                    #
#      http://www.apache.org/licenses/LICENSE-2.0                                                                    #
#                                                                                                                    #
#  or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES #
#  OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions    #
#  and limitations under the License.                                                                                #
######################################################################################################################
import json
from os import environ

import boto3

from shared_util import custom_boto_config, custom_logging, service_helper

logger = custom_logging.get_logger(__name__)


class EnvironmentVariableError(Exception):
    """Raised when a required environment variable is missing or empty."""


def get_service_client(service_name):
    """
    Get a boto3 client for the given AWS service.

    Delegates to ``service_helper.get_service_client``; any client reuse or
    caching (e.g. a "global" client) is decided there, not in this module.

    :param service_name: boto3 service name (e.g. "quicksight", "sts")
    :return: the client returned by ``service_helper.get_service_client``
    """
    return service_helper.get_service_client(service_name)


def get_quicksight_client():
    """Get the boto3 client for Amazon QuickSight via get_service_client."""
    return get_service_client("quicksight")


def get_sts_client():
    """Get the boto3 client for AWS STS via get_service_client."""
    return get_service_client("sts")


def get_aws_partition():
    """
    Get the caller's AWS partition by deriving it from the AWS region.

    The region is obtained through get_aws_region(), so a missing or empty
    AWS_REGION environment variable raises EnvironmentVariableError instead of
    failing with an AttributeError on ``None``.

    :return: partition name for the current AWS region:
        "aws-cn" for regions starting with "cn" (China),
        "aws-us-gov" for regions starting with "us-gov" (AWS GovCloud (US)),
        "aws" for every other region
    :raises EnvironmentVariableError: if AWS_REGION is not set or is empty
    """
    # Reuse the validated accessor rather than reading environ directly, so an
    # unset region produces the module's own descriptive error.
    region_name = get_aws_region()

    china_region_name_prefix = "cn"
    us_gov_cloud_region_name_prefix = "us-gov"

    aws_regions_partition = "aws"
    aws_china_regions_partition = "aws-cn"
    aws_us_gov_cloud_regions_partition = "aws-us-gov"

    # China regions
    if region_name.startswith(china_region_name_prefix):
        return aws_china_regions_partition
    # AWS GovCloud (US) regions
    if region_name.startswith(us_gov_cloud_region_name_prefix):
        return aws_us_gov_cloud_regions_partition
    return aws_regions_partition


def get_aws_region():
    """
    Get the caller's AWS region from the environment variable AWS_REGION.

    :return: the AWS region name (e.g. us-east-1)
    :raises EnvironmentVariableError: if AWS_REGION is not set or is empty
    """
    region = environ.get("AWS_REGION")
    if not region:
        raise EnvironmentVariableError("Missing AWS_REGION environment variable.")
    return region


def get_aws_account_id():
    """
    Get the caller's AWS account ID.

    Calls STS ``get_caller_identity`` and reads the "Account" field of the
    response.

    :return: the AWS account ID, or None if the response has no "Account" key
    """
    sts_client = get_sts_client()
    identity = sts_client.get_caller_identity()
    return identity.get("Account")