#!/usr/bin/env python3


import configparser
import datetime
import json
import os
import pathlib
import subprocess
import sys
import time

import boto3


# Profiles to refresh: every name given on the command line, plus the value
# of $AWS_PROFILE when it is set to a non-empty string. An empty set means
# no filtering is applied.
profile_variable = os.environ.get('AWS_PROFILE')
profile_filter = set(sys.argv[1:]) | ({profile_variable} if profile_variable else set())


home_folder = pathlib.Path.home()
# The CLI credential cache is always looked up at its default location under
# the home folder; no override is consulted for it.
cli_cache_folder = f'{home_folder}/.aws/cli/cache'
# Honour the AWS_CONFIG_FILE / AWS_SHARED_CREDENTIALS_FILE overrides used by
# the AWS CLI and boto3, so this script reads and writes the same files as the
# tools it drives (e.g. boto3 in get_sts_credentials). An unset or empty
# variable falls back to the default path under ~/.aws.
aws_config_filename = os.path.expanduser(
    os.environ.get('AWS_CONFIG_FILE') or f'{home_folder}/.aws/config')
aws_credentials_filename = os.path.expanduser(
    os.environ.get('AWS_SHARED_CREDENTIALS_FILE') or f'{home_folder}/.aws/credentials')


def get_config(config_filename):
    """Parse an AWS-style INI file and return the ConfigParser.

    Interpolation is disabled: AWS config and credentials files do not use
    ConfigParser's ``%(name)s`` syntax. With the default BasicInterpolation, a
    literal ``%`` in a value (a URL-encoded start URL, a credential_process
    command, ...) makes ``get`` raise InterpolationSyntaxError and makes ``set``
    raise ValueError.

    A missing file is silently ignored (ConfigParser.read semantics), which
    yields an empty configuration.

    Args:
        config_filename: path of the INI file to read.

    Returns:
        configparser.ConfigParser holding the parsed sections.
    """
    config = configparser.ConfigParser(interpolation=None)
    config.read(config_filename)
    return config


def get_profile(section):
    """Map one section of the AWS config file to a credential-refresh job.

    Only sections named exactly ``profile <name>`` are considered. Any other
    section (``default``, ``sso-session ...``, ``services ...``) yields None.
    The backend is chosen from the options present in the section, in this
    order:

    - aws-azure-login (``azure_tenant_id``)
    - gimme-aws-creds (``okta_profile`` + ``okta_account_id`` + ``okta_role_name``)
    - AWS SSO (``sso_account_id`` + ``sso_role_name`` + ``sso_start_url``)
    - STS assume-role (``source_profile`` + ``role_arn``)
    - aws-vault for everything else

    Args:
        section: section name as returned by ``aws_config.sections()``.

    Returns:
        None for non-profile sections. Otherwise a tuple
        ``(fetcher, profile[, account[, role[, extra]]])`` that can be passed as
        ``get_credentials(*job)``.
    """
    prefix = 'profile '
    # Exact prefix match: a substring test would also accept e.g.
    # "[sso-session my-profile]" and treat it as an aws-vault profile.
    if not section.startswith(prefix):
        return None
    # Strip only the leading prefix, not every occurrence inside the name.
    profile = section[len(prefix):]

    def option(name):
        # Missing options read as None rather than raising.
        return aws_config.get(section, name, fallback=None)

    azure_tenant_id = option('azure_tenant_id')
    if azure_tenant_id:
        return get_azure_credentials, profile, azure_tenant_id

    okta_profile = option('okta_profile')
    okta_account_id = option('okta_account_id')
    okta_role_name = option('okta_role_name')
    if okta_profile and okta_account_id and okta_role_name:
        return get_okta_credentials, profile, okta_account_id, okta_role_name, okta_profile

    sso_account_id = option('sso_account_id')
    sso_role_name = option('sso_role_name')
    sso_start_url = option('sso_start_url')
    if sso_account_id and sso_role_name and sso_start_url:
        return get_sso_credentials, profile, sso_account_id, sso_role_name

    source_profile = option('source_profile')
    role_arn = option('role_arn')
    if source_profile and role_arn:
        # Expected shape: arn:<partition>:iam::<account>:role/<name>
        arn_components = role_arn.split(':')
        account_id = arn_components[4]
        role_name = arn_components[5].replace('role/', '')
        extra = source_profile, role_arn
        return get_sts_credentials, profile, account_id, role_name, extra

    return get_vault_credentials, profile


def run_sso_refresh_command(profile):
    """Make an authenticated AWS CLI call using ``profile``.

    Runs ``aws sts --profile <profile> get-caller-identity`` purely for its
    side effect of making the CLI resolve the profile's credentials. Its
    stdout is captured and discarded.

    Raises:
        subprocess.CalledProcessError: if the CLI exits non-zero.
    """
    subprocess.run(
        ['aws', 'sts', '--profile', profile, 'get-caller-identity'],
        check=True,
        stdout=subprocess.PIPE,
    )


def run_sso_login_command(profile):
    """Start an interactive ``aws sso login`` for ``profile``.

    The command's stdout is captured and discarded.

    Raises:
        subprocess.CalledProcessError: if the login exits non-zero.
    """
    subprocess.run(
        ['aws', 'sso', '--profile', profile, 'login'],
        check=True,
        stdout=subprocess.PIPE,
    )


def get_azure_credentials(profile, account, role, extra):
    """Refresh ``profile`` with aws-azure-login and read the new keys back.

    aws-azure-login writes the refreshed keys into the shared credentials
    file. That file is re-read from disk here, and the ``[profile]`` section is
    returned in the common credentials-dict format.

    Args:
        profile: AWS profile name; returned unchanged.
        account: the profile's ``azure_tenant_id`` (not read here).
        role, extra: unused; present for the common fetcher signature.

    Returns:
        ``(profile, credentials)``. ``credentials`` has the keys accessKeyId,
        secretAccessKey, sessionToken and expiration (taken from
        ``aws_expiration``).

    Raises:
        subprocess.CalledProcessError: if aws-azure-login fails.
        configparser.Error: if the expected section or options are missing.
    """
    # Pass --profile explicitly. Without it, aws-azure-login refreshes
    # $AWS_PROFILE (or "default") instead of the profile being processed.
    command = ('aws-azure-login', '--profile', profile, '--force-refresh', '--no-prompt')
    subprocess.run(command, check=True)
    new_aws_credentials = get_config(aws_credentials_filename)
    credentials = {
        'accessKeyId': new_aws_credentials.get(profile, 'aws_access_key_id'),
        'secretAccessKey': new_aws_credentials.get(profile, 'aws_secret_access_key'),
        'sessionToken': new_aws_credentials.get(profile, 'aws_session_token'),
        'expiration': new_aws_credentials.get(profile, 'aws_expiration'),
    }
    return profile, credentials


def get_okta_credentials(profile, account, role, extra):
    """Fetch temporary credentials for ``profile`` with gimme-aws-creds.

    Args:
        profile: AWS profile name; returned unchanged.
        account: AWS account id of the role to assume.
        role: IAM role name within that account.
        extra: the gimme-aws-creds profile name (``okta_profile`` option).

    Returns:
        ``(profile, credentials)``. The expiration is not reported by this
        code path; it is assumed to be one hour from now and formatted as
        naive UTC ISO-8601 with a trailing 'Z'.

    Raises:
        subprocess.CalledProcessError: if gimme-aws-creds exits non-zero.

    NOTE(review): ``json.loads`` on the whole stdout assumes gimme-aws-creds
    is configured to print a single JSON document — confirm.
    """
    okta_profile = extra
    role_arn = f'arn:aws:iam::{account}:role/{role}'
    command = ('gimme-aws-creds', '--profile', okta_profile, '--roles', role_arn)
    output = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    output_credentials = json.loads(output.stdout)['credentials']
    # datetime.utcnow() is deprecated since Python 3.12. Derive the same naive
    # UTC value from an aware "now" so the '...Z' string format is unchanged.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    expiration = now_utc + datetime.timedelta(hours=1)
    credentials = {
        'accessKeyId': output_credentials['aws_access_key_id'],
        'secretAccessKey': output_credentials['aws_secret_access_key'],
        'sessionToken': output_credentials['aws_session_token'],
        'expiration': expiration.isoformat() + 'Z',
    }
    return profile, credentials


def get_sso_credentials(profile, account, role, extra):
    """Fetch temporary credentials for an AWS SSO profile via the CLI cache.

    First makes an authenticated CLI call with ``profile``. If that fails,
    runs ``aws sso login`` once and retries. The credentials are then read
    back from the JSON files in ``cli_cache_folder``, picking the entry whose
    ``Credentials.Expiration`` sorts last.

    Args:
        profile: AWS profile name; returned unchanged.
        account, role: the profile's sso_account_id / sso_role_name (not read
            here).
        extra: unused; present for the common fetcher signature.

    Returns:
        ``(profile, credentials)``. The expiration string has 'UTC' replaced
        by 'Z'.

    Raises:
        subprocess.CalledProcessError: if the CLI call still fails after login.
        RuntimeError: if the cache folder holds no readable credential entry.

    NOTE(review): cache entries are not matched to ``profile``. Suppose the
    CLI reused a still-valid entry for this profile, and another profile's
    entry expires later. Then that other profile's credentials are returned.
    Confirm whether entries should be matched by cache key instead.
    """
    try:
        run_sso_refresh_command(profile)
    except subprocess.CalledProcessError:
        run_sso_login_command(profile)
        run_sso_refresh_command(profile)
    # Presumably gives the CLI time to finish writing its cache file.
    time.sleep(2.0)

    cached = []
    for name in os.listdir(cli_cache_folder):
        path = os.path.join(cli_cache_folder, name)
        # Skip anything that is not a readable JSON credential entry rather
        # than aborting the whole run. `with` closes each file promptly.
        try:
            with open(path) as cache_file:
                entry = json.load(cache_file)
            expiration = entry['Credentials']['Expiration']
        except (OSError, ValueError, KeyError, TypeError):
            continue
        cached.append((expiration, entry['Credentials']))

    if not cached:
        raise RuntimeError(f'No cached credentials found in {cli_cache_folder}')

    # Expiration values are compared as stored (strings). This assumes every
    # entry uses the same timestamp format.
    _, newest = max(cached, key=lambda item: item[0])
    credentials = {
        'accessKeyId': newest['AccessKeyId'],
        'secretAccessKey': newest['SecretAccessKey'],
        'sessionToken': newest['SessionToken'],
        'expiration': newest['Expiration'].replace('UTC', 'Z'),
    }
    return profile, credentials


def get_sts_credentials(profile, account, role, extra):
    """Assume a role with the credentials of its source profile via boto3 STS.

    Args:
        profile: AWS profile name; returned unchanged.
        account, role: parsed from the role ARN (not read here).
        extra: ``(source_profile, role_arn)`` tuple.

    Returns:
        ``(profile, credentials)``. The expiration is formatted as naive UTC
        ISO-8601 with a trailing 'Z'.
    """
    source_profile, role_arn = extra
    session = boto3.session.Session(profile_name=source_profile)
    sts = session.client('sts')
    response = sts.assume_role(RoleArn=role_arn, RoleSessionName='aws-assume-role')
    returned = response['Credentials']
    expiration = returned['Expiration']
    # botocore parses Expiration into a timezone-aware datetime. Its
    # isoformat() already ends in '+00:00', so appending 'Z' would give a
    # malformed '...+00:00Z'. Normalise to naive UTC before formatting.
    if expiration.tzinfo is not None:
        expiration = expiration.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    credentials = {
        'accessKeyId': returned['AccessKeyId'],
        'secretAccessKey': returned['SecretAccessKey'],
        'sessionToken': returned['SessionToken'],
        'expiration': expiration.isoformat() + 'Z',
    }
    return profile, credentials


def get_vault_credentials(profile, account, role, extra):
    """Fetch temporary credentials for ``profile`` through aws-vault.

    Runs ``aws-vault exec <profile> env`` and collects the variables whose
    names start with ``AWS`` from the printed environment. The expiration is
    not read from aws-vault; it is assumed to be 12 hours from now.

    Args:
        profile: AWS profile name; returned unchanged.
        account, role, extra: unused; present for the common fetcher signature.

    Returns:
        ``(profile, credentials)``. The expiration is formatted as naive UTC
        ISO-8601 with a trailing 'Z'.

    Raises:
        subprocess.CalledProcessError: if aws-vault exits non-zero.
        KeyError: if an expected AWS_* variable is missing from the output.
    """
    command = ('aws-vault', 'exec', profile, 'env')
    # check=True reports aws-vault failures directly, consistent with the
    # other fetchers, instead of as a KeyError on empty output.
    vault_process = subprocess.run(
        command, stdout=subprocess.PIPE, encoding='utf-8', check=True)
    variables = {}
    # `env` prints one NAME=value per line. Split per line (not per
    # whitespace), so values containing spaces are not broken into bogus
    # fragments. Split on the first '=' only, so values keep any '=' they
    # contain.
    for line in vault_process.stdout.splitlines():
        name, _, value = line.strip().partition('=')
        if name.startswith('AWS'):
            variables[name] = value
    # datetime.utcnow() is deprecated since Python 3.12. Same naive UTC value.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    expiration = now_utc + datetime.timedelta(hours=12)
    credentials = {
        'accessKeyId': variables['AWS_ACCESS_KEY_ID'],
        'secretAccessKey': variables['AWS_SECRET_ACCESS_KEY'],
        'sessionToken': variables['AWS_SESSION_TOKEN'],
        'expiration': expiration.isoformat() + 'Z',
    }
    return profile, credentials


def get_credentials(process, profile, account=None, role=None, extra=None):
    """Log which profile is being refreshed, then delegate to ``process``.

    Args:
        process: fetcher callable invoked as ``process(profile, account, role, extra)``.
        profile: AWS profile name.
        account, role: when both are truthy, the log line also shows the
            corresponding IAM role ARN.
        extra: backend-specific data forwarded unchanged.

    Returns:
        Whatever ``process`` returns.
    """
    message = f'Getting credentials for profile {profile}'
    if account and role:
        message += f' (arn:aws:iam::{account}:role/{role})'
    print(message)
    return process(profile, account, role, extra)


def set_credentials(profile, credentials):
    """Store ``credentials`` in the ``[profile]`` section of ``aws_credentials``.

    Only the in-memory parser is updated; nothing is written to disk here. The
    section is created if it does not exist yet.

    Args:
        profile: section name in the credentials file.
        credentials: dict with the keys accessKeyId, secretAccessKey,
            sessionToken and expiration.
    """
    if not aws_credentials.has_section(profile):
        aws_credentials.add_section(profile)
    # (option written to the file, key in the credentials dict), in file order.
    field_map = (
        ('aws_access_key_id', 'accessKeyId'),
        ('aws_secret_access_key', 'secretAccessKey'),
        ('aws_session_token', 'sessionToken'),
        ('expires', 'expiration'),
    )
    for option_name, credential_key in field_map:
        aws_credentials.set(profile, option_name, credentials[credential_key])
    print(f'Saved credentials for profile {profile}')


def process_profile(params):
    """Fetch credentials for one refresh job and store them in memory.

    Args:
        params: a job tuple as produced by ``get_profile``, i.e. the
            positional arguments for ``get_credentials``.
    """
    fetched_profile, fetched_credentials = get_credentials(*params)
    set_credentials(fetched_profile, fetched_credentials)


aws_config = get_config(aws_config_filename)
aws_credentials = get_config(aws_credentials_filename)


def _open_private(path, flags):
    """``open()`` opener that creates new files readable by the owner only."""
    return os.open(path, flags, 0o600)


def _save_credentials():
    """Overwrite the shared credentials file with ``aws_credentials``.

    The file holds secret keys, so a newly created file gets mode 0o600
    (subject to the umask). The permissions of an existing file are left
    unchanged.
    """
    with open(aws_credentials_filename, 'w', opener=_open_private) as credentials_file:
        aws_credentials.write(credentials_file)


# One refresh job per [profile ...] section that get_profile recognises.
profiles = [p for p in map(get_profile, aws_config.sections()) if p]

# An empty filter means "refresh every profile".
filtered_profiles = [p for p in profiles if not profile_filter or p[1] in profile_filter]

# Assume-role jobs carry (source_profile, role_arn) as their last element.
# Their source profiles are refreshed and saved first, because
# get_sts_credentials opens a boto3 session on the source profile (presumably
# so that session picks up the fresh keys from the file).
source_profile_names = [p[4][0] for p in filtered_profiles if p[0] is get_sts_credentials]
source_profiles = [p for p in profiles if p[1] in source_profile_names]
filtered_profiles = [p for p in filtered_profiles if p not in source_profiles]

for profile in source_profiles:
    process_profile(profile)
_save_credentials()

for profile in filtered_profiles:
    process_profile(profile)
_save_credentials()