#!/usr/bin/env python3
"""Refresh temporary AWS credentials and store them in ~/.aws/credentials.

Every ``[profile <name>]`` section of ``~/.aws/config`` is mapped to one of
five credential providers (Azure, Okta, AWS SSO, STS assume-role or
aws-vault) based on the keys the section defines.  The resulting credentials
are written to ``~/.aws/credentials`` under the profile's name.

Usage: pass profile names as command-line arguments to refresh only those
profiles; the value of ``AWS_PROFILE`` (if set) is added to that list.  With
no names at all, every profile is refreshed.
"""
import configparser
import datetime
import json
import os
import pathlib
import subprocess
import sys
import time

import boto3

# Names of the profiles to refresh; an empty set means "all profiles".
profile_filter = set(sys.argv[1:])
profile_variable = os.environ.get('AWS_PROFILE')
if profile_variable:
    profile_filter.add(profile_variable)

home_folder = pathlib.Path.home()
cli_cache_folder = f'{home_folder}/.aws/cli/cache'
aws_config_filename = f'{home_folder}/.aws/config'
aws_credentials_filename = f'{home_folder}/.aws/credentials'

# Prefix that marks a named-profile section in the AWS config file.
_PROFILE_PREFIX = 'profile '


def get_config(config_filename):
    """Parse an INI file and return the resulting ``ConfigParser``.

    A missing file is not an error: ``ConfigParser.read`` skips it and the
    returned parser is simply empty.
    """
    config = configparser.ConfigParser()
    config.read(config_filename)
    return config


def _format_expiration(moment):
    """Return ``moment`` as an ISO-8601 UTC string with a ``Z`` suffix.

    Aware datetimes are converted to UTC and stripped of their offset first,
    because ``isoformat()`` on an aware value already ends in ``+00:00`` and
    appending ``Z`` to that would produce a malformed timestamp.  Naive
    datetimes are taken to be UTC already.
    """
    if moment.tzinfo is not None:
        moment = moment.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return moment.isoformat() + 'Z'


def get_profile(section):
    """Map a config section to the arguments for ``get_credentials``.

    Args:
        section: A section name from ``aws_config``.

    Returns:
        ``None`` when ``section`` is not a ``profile <name>`` section.
        Otherwise a tuple ``(process, profile[, account[, role[, extra]]])``
        whose first item is the provider function and whose remaining items
        are its arguments; the tuple length depends on the provider.
    """
    # Match on the prefix only, so that other section kinds whose name merely
    # contains the word "profile" are not mistaken for profiles.
    if not section.startswith(_PROFILE_PREFIX):
        return None
    profile = section[len(_PROFILE_PREFIX):]

    def option(name):
        return aws_config.get(section, name, fallback=None)

    azure_tenant_id = option('azure_tenant_id')
    okta_profile = option('okta_profile')
    okta_account_id = option('okta_account_id')
    okta_role_name = option('okta_role_name')
    sso_account_id = option('sso_account_id')
    sso_role_name = option('sso_role_name')
    sso_start_url = option('sso_start_url')
    source_profile = option('source_profile')
    role_arn = option('role_arn')

    if azure_tenant_id:
        return get_azure_credentials, profile, azure_tenant_id
    if okta_profile and okta_account_id and okta_role_name:
        return get_okta_credentials, profile, okta_account_id, okta_role_name, okta_profile
    if sso_account_id and sso_role_name and sso_start_url:
        return get_sso_credentials, profile, sso_account_id, sso_role_name
    if source_profile and role_arn:
        # arn:aws:iam::<account>:role/<name> -> fields 4 and 5 when split on ':'
        arn_components = role_arn.split(':')
        account_id = arn_components[4]
        role_name = arn_components[5].replace('role/', '')
        extra = source_profile, role_arn
        return get_sts_credentials, profile, account_id, role_name, extra
    # Anything else falls through to aws-vault.
    return get_vault_credentials, profile


def run_sso_refresh_command(profile):
    """Run ``aws sts get-caller-identity`` for ``profile``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    command = ('aws', 'sts', '--profile', profile, 'get-caller-identity')
    subprocess.run(command, stdout=subprocess.PIPE, check=True)


def run_sso_login_command(profile):
    """Run ``aws sso login`` for ``profile``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    command = ('aws', 'sso', '--profile', profile, 'login')
    subprocess.run(command, stdout=subprocess.PIPE, check=True)


def get_azure_credentials(profile, account, role, extra):
    """Run ``aws-azure-login`` and read the profile's keys back from disk.

    ``account``, ``role`` and ``extra`` are accepted for signature parity
    with the other providers and are not used.

    Returns:
        ``(profile, credentials)`` where ``credentials`` is a dict with the
        keys ``accessKeyId``, ``secretAccessKey``, ``sessionToken`` and
        ``expiration``.

    Raises:
        subprocess.CalledProcessError: If ``aws-azure-login`` fails.
        configparser.Error: If the credentials file lacks the profile or one
            of the expected options afterwards.
    """
    # NOTE(review): the profile name is not passed to aws-azure-login, so the
    # tool picks its own target profile - confirm this is intended.
    command = ('aws-azure-login', '--force-refresh', '--no-prompt')
    subprocess.run(command, check=True)
    # Re-read the file from disk; the in-memory ``aws_credentials`` predates
    # the login that just ran.
    new_aws_credentials = get_config(aws_credentials_filename)
    credentials = {
        'accessKeyId': new_aws_credentials.get(profile, 'aws_access_key_id'),
        'secretAccessKey': new_aws_credentials.get(profile, 'aws_secret_access_key'),
        'sessionToken': new_aws_credentials.get(profile, 'aws_session_token'),
        'expiration': new_aws_credentials.get(profile, 'aws_expiration'),
    }
    return profile, credentials


def get_okta_credentials(profile, account, role, extra):
    """Obtain credentials for ``account``/``role`` through ``gimme-aws-creds``.

    Args:
        profile: Name of the AWS profile being refreshed.
        account: AWS account id used to build the role ARN.
        role: Role name used to build the role ARN.
        extra: Name of the Okta profile passed to ``gimme-aws-creds``.

    Returns:
        ``(profile, credentials)``; see ``get_azure_credentials`` for the
        dict layout.  The expiration is set to one hour from now, it is not
        read from the tool's output.

    Raises:
        subprocess.CalledProcessError: If ``gimme-aws-creds`` fails.
    """
    okta_profile = extra
    role_arn = f'arn:aws:iam::{account}:role/{role}'
    command = ('gimme-aws-creds', '--profile', okta_profile, '--roles', role_arn)
    output = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    # The command's stdout is parsed as a single JSON document.
    output_credentials = json.loads(output.stdout)['credentials']
    expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
    credentials = {
        'accessKeyId': output_credentials['aws_access_key_id'],
        'secretAccessKey': output_credentials['aws_secret_access_key'],
        'sessionToken': output_credentials['aws_session_token'],
        'expiration': _format_expiration(expiration),
    }
    return profile, credentials


def get_sso_credentials(profile, account, role, extra):
    """Refresh an AWS SSO profile and read its credentials from the CLI cache.

    The refresh command is tried first; if it fails, an interactive
    ``aws sso login`` is performed and the refresh is retried.  ``account``,
    ``role`` and ``extra`` are not used.

    Returns:
        ``(profile, credentials)``; see ``get_azure_credentials`` for the
        dict layout.

    Raises:
        subprocess.CalledProcessError: If login or the retried refresh fails.
    """
    try:
        run_sso_refresh_command(profile)
    except subprocess.CalledProcessError:
        run_sso_login_command(profile)
        run_sso_refresh_command(profile)
    # Presumably gives the CLI time to finish writing its cache file before
    # it is read below - TODO confirm the delay is still needed.
    time.sleep(2.0)

    cache_credentials = []
    for entry in os.listdir(cli_cache_folder):
        # Close every cache file right away instead of leaking the handle.
        with open(os.path.join(cli_cache_folder, entry)) as cache_file:
            cache_credentials.append(json.load(cache_file))

    # NOTE(review): this picks the entry with the latest expiration across
    # the whole cache folder; it is assumed to belong to ``profile`` because
    # that profile was refreshed a moment ago.
    newest_credentials = sorted(cache_credentials, key=lambda c: c['Credentials']['Expiration'])[-1]
    newest = newest_credentials['Credentials']
    credentials = {
        'accessKeyId': newest['AccessKeyId'],
        'secretAccessKey': newest['SecretAccessKey'],
        'sessionToken': newest['SessionToken'],
        'expiration': newest['Expiration'].replace('UTC', 'Z'),
    }
    return profile, credentials


def get_sts_credentials(profile, account, role, extra):
    """Assume a role via STS using the credentials of a source profile.

    Args:
        profile: Name of the AWS profile being refreshed.
        account: Unused; the role ARN in ``extra`` is authoritative.
        role: Unused; see ``account``.
        extra: ``(source_profile, role_arn)`` pair.

    Returns:
        ``(profile, credentials)``; see ``get_azure_credentials`` for the
        dict layout.
    """
    source_profile, role_arn = extra
    session = boto3.session.Session(profile_name=source_profile)
    sts = session.client('sts')
    response = sts.assume_role(RoleArn=role_arn, RoleSessionName='aws-assume-role')
    assumed = response['Credentials']
    credentials = {
        'accessKeyId': assumed['AccessKeyId'],
        'secretAccessKey': assumed['SecretAccessKey'],
        'sessionToken': assumed['SessionToken'],
        # boto3 returns a timezone-aware datetime here; normalise it so the
        # stored value is "...Z" rather than "...+00:00Z".
        'expiration': _format_expiration(assumed['Expiration']),
    }
    return profile, credentials


def get_vault_credentials(profile, account, role, extra):
    """Read credentials from the environment that ``aws-vault exec`` provides.

    ``account``, ``role`` and ``extra`` are not used.

    Returns:
        ``(profile, credentials)``; see ``get_azure_credentials`` for the
        dict layout.  The expiration is set to twelve hours from now.

    Raises:
        subprocess.CalledProcessError: If ``aws-vault`` exits non-zero.
        KeyError: If the environment lacks one of the expected variables.
    """
    command = ('aws-vault', 'exec', profile, 'env')
    # check=True so that a failing aws-vault is reported as such instead of
    # showing up later as a KeyError on a missing variable.
    vault_process = subprocess.run(
        command, stdout=subprocess.PIPE, encoding='utf-8', check=True)

    variables = {}
    for line in vault_process.stdout.splitlines():
        # Split on the first '=' only; values may contain '=' themselves.
        name, separator, value = line.strip().partition('=')
        if separator and name.startswith('AWS'):
            variables[name] = value

    expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=12)
    credentials = {
        'accessKeyId': variables['AWS_ACCESS_KEY_ID'],
        'secretAccessKey': variables['AWS_SECRET_ACCESS_KEY'],
        'sessionToken': variables['AWS_SESSION_TOKEN'],
        'expiration': _format_expiration(expiration),
    }
    return profile, credentials


def get_credentials(process, profile, account=None, role=None, extra=None):
    """Announce the profile being refreshed and invoke its provider.

    Args:
        process: Provider function, called as
            ``process(profile, account, role, extra)``.
        profile: Name of the AWS profile.
        account: Optional account id (provider specific).
        role: Optional role name (provider specific).
        extra: Optional provider-specific payload.

    Returns:
        Whatever ``process`` returns, i.e. ``(profile, credentials)``.
    """
    if account and role:
        print(f'Getting credentials for profile {profile} (arn:aws:iam::{account}:role/{role})')
    else:
        print(f'Getting credentials for profile {profile}')
    return process(profile, account, role, extra)


def set_credentials(profile, credentials):
    """Store ``credentials`` for ``profile`` in the in-memory credentials.

    Only ``aws_credentials`` is updated here; writing the file is done by
    the caller.
    """
    if not aws_credentials.has_section(profile):
        aws_credentials.add_section(profile)
    aws_credentials.set(profile, 'aws_access_key_id', credentials['accessKeyId'])
    aws_credentials.set(profile, 'aws_secret_access_key', credentials['secretAccessKey'])
    aws_credentials.set(profile, 'aws_session_token', credentials['sessionToken'])
    aws_credentials.set(profile, 'expires', credentials['expiration'])
    print(f'Saved credentials for profile {profile}')


def process_profile(params):
    """Fetch and store credentials for one tuple produced by ``get_profile``."""
    profile, credentials = get_credentials(*params)
    set_credentials(profile, credentials)


def _save_credentials():
    """Write the in-memory credentials to ``aws_credentials_filename``."""
    with open(aws_credentials_filename, 'w') as credentials_file:
        aws_credentials.write(credentials_file)


aws_config = get_config(aws_config_filename)
aws_credentials = get_config(aws_credentials_filename)

profiles = (get_profile(s) for s in aws_config.sections())
profiles = [p for p in profiles if p]
filtered_profiles = [p for p in profiles if not profile_filter or p[1] in profile_filter]

# Profiles that a selected STS profile assumes its role from.  They are taken
# from the unfiltered list, so they are refreshed even when not requested.
source_profile_names = [p[4][0] for p in filtered_profiles if p[0] == get_sts_credentials]
source_profiles = [p for p in profiles if p[1] in source_profile_names]
filtered_profiles = [p for p in filtered_profiles if p not in source_profiles]

# Source profiles go first and are written to disk straight away: the STS
# provider opens a boto3 session by profile name, which reads the file.
# Saving after every profile also keeps what was refreshed so far if a later
# profile fails.
for profile in source_profiles:
    process_profile(profile)
    _save_credentials()

for profile in filtered_profiles:
    process_profile(profile)
    _save_credentials()