import os
import sys
import boto3
import botocore
import json
import awswrangler as wr
import tempfile
import time
import ast
from collections import Counter
from configparser import ConfigParser
from urllib.parse import urlparse, urlunparse
from awsglue.utils import getResolvedOptions

information_schema_name = "information_schema"
df_index = ['table_schema', 'table_name']


def get_config(s3_config_bucket, s3_config_file):
    """Read the job configuration (INI format) from S3 and return a ConfigParser."""
    s3 = boto3.client('s3')
    body_content = s3.get_object(Bucket=s3_config_bucket, Key=s3_config_file)['Body'].read().decode('utf-8')
    config = ConfigParser()
    config.read_string(body_content)
    # config.read("glue_config.conf")
    return config


def get_client(region_name, service):
    session = boto3.Session(region_name=region_name)
    return session.client(service)


aws_region_list = ['us-east-2', 'us-east-1', 'us-west-1', 'us-west-2', 'af-south-1', 'ap-east-1', 'ap-south-1',
                   'ap-northeast-3', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1',
                   'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-south-1', 'eu-west-3',
                   'eu-north-1', 'me-south-1', 'sa-east-1']


def compare_df(source_df, target_df, df_index):
    """Outer-join the source and target table listings on their common columns (table_schema,
    table_name) and split the result into matched, source-only and target-only frames."""
    merged_df = source_df.merge(target_df, how='outer', indicator=True)
    return (merged_df[merged_df['_merge'] == 'both'],
            merged_df[merged_df['_merge'] == 'left_only'],
            merged_df[merged_df['_merge'] == 'right_only'])


def get_database_name(row):
    """Return the database name referenced by a Lake Formation permission record."""
    resource_name = list(row.get('Resource', {}))
    if resource_name == ['Database']:
        return row.get('Resource', {}).get('Database', {}).get('Name', None)
    elif resource_name == ['Table']:
        return row.get('Resource', {}).get('Table', {}).get('DatabaseName', None)
    elif resource_name == ['TableWithColumns']:
        return row.get('Resource', {}).get('TableWithColumns', {}).get('DatabaseName', None)
    else:
        return None


def store_permission_data(permission_data, permissions_from_region, lf_storage_bucket, lf_storage_folder, lf_storage_file_name):
    """Write the permission records (one JSON document per line) to a temp file and upload it to S3."""
    f = tempfile.TemporaryFile(mode='w+')
    for pd in permission_data:
        f.write(json.dumps(pd) + "\n")
    f.seek(0)  # flush pending writes and rewind so the upload reads from the beginning
    rf = open(f.name, 'r+b')
    output_file_name = f"s3://{lf_storage_bucket}/{lf_storage_folder}/{permissions_from_region}/{lf_storage_file_name}"
    print(f"Writing to output file name {output_file_name}")
    wr.s3.upload(local_file=rf, path=output_file_name)
    f.close()


def apply_table_permissions(file_location, destination_client, db_list, source_region, lf_storage_bucket, lf_storage_folder, lf_storage_file_name):
    """Download the stored permission records from S3 and re-grant them in the destination region."""
    print("Reading permissions from s3 location")
    f = tempfile.TemporaryFile(mode='w+b')
    s3_client = get_client(source_region, 's3')
    s3_client.download_fileobj(f"{lf_storage_bucket}", f"{lf_storage_folder}/{source_region}/{lf_storage_file_name}", f)
    print(f"{lf_storage_bucket}/{lf_storage_folder}/{source_region}/{lf_storage_file_name}")
    f.seek(0)
    rf = open(f.name, 'r+')
    for r_row in rf.readlines():
        row = json.loads(r_row)
        database_name = get_database_name(row)
        if database_name in db_list:
            print(f"Applying {row}")
            # Translate the ALL_TABLES placeholder back into a TableWildcard resource.
            if 'Table' in row['Resource'] and row['Resource']['Table'].get('Name') == 'ALL_TABLES':
                del row['Resource']['Table']['Name']
                row['Resource']['Table']['TableWildcard'] = {}
            if 'TableWithColumns' in row['Resource'] and row['Resource']['TableWithColumns'].get('Name') == 'ALL_TABLES':
                row['Resource']['Table'] = row['Resource']['TableWithColumns']
                del row['Resource']['Table']['Name']
                row['Resource']['Table'].pop('ColumnWildcard', None)
                del row['Resource']['TableWithColumns']
                row['Resource']['Table']['TableWildcard'] = {}
            # grant_permissions accepts only a subset of the fields that list_permissions returns,
            # so drop read-only attributes (e.g. LastUpdated) before re-granting.
            grant_args = {k: v for k, v in row.items()
                          if k in ('CatalogId', 'Principal', 'Resource', 'Permissions', 'PermissionsWithGrantOption')}
            response = destination_client.grant_permissions(**grant_args)
    f.close()
    print("Done applying table permissions")
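# Illustrative only: one line of the stored permissions file, in the shape this script expects
# from list_permissions and feeds back to grant_permissions. The account ID, role ARN and
# database name are hypothetical.
#
#   {"Principal": {"DataLakePrincipalIdentifier": "arn:aws:iam::111122223333:role/example-analyst-role"},
#    "Resource": {"Table": {"DatabaseName": "sales_db", "Name": "ALL_TABLES"}},
#    "Permissions": ["SELECT", "DESCRIBE"],
#    "PermissionsWithGrantOption": []}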
def get_permissions(source_client):
    """Return all Lake Formation permission records, following NextToken pagination."""
    print("Processing permissions")
    result = source_client.list_permissions()
    principal_permissions = result['PrincipalResourcePermissions']
    fetch = True
    while fetch:
        try:
            token = result['NextToken']
            result = source_client.list_permissions(NextToken=token)
            principal_permissions.extend(result['PrincipalResourcePermissions'])
        except KeyError:
            # No NextToken left: the last page has been read.
            fetch = False
    return principal_permissions


def create_table(glue_client, db_name, table):
    try:
        glue_client.create_table(DatabaseName=db_name, TableInput=table)
    except glue_client.exceptions.AlreadyExistsException:
        glue_client.update_table(DatabaseName=db_name, TableInput=table)


def create_database(glue_client, database_input):
    try:
        res = glue_client.create_database(DatabaseInput=database_input)
    except glue_client.exceptions.AlreadyExistsException:
        res = glue_client.update_database(DatabaseInput=database_input, Name=database_input['Name'])


def restore_data(config, data_source, glue_client, update_table_s3_location, table_s3_mapping):
    """Read the extracted catalog file from S3 and recreate the databases and tables in the
    destination region, optionally rewriting S3 locations via table_s3_mapping."""
    print("Restoring database...")
    # glue_client = get_glue_client(config[data_source]['destination_region'])
    database_count = Counter()
    table_count = Counter()
    s3_path = config[data_source]['s3_data_path']
    f = tempfile.TemporaryFile(mode='w+b')
    wr.s3.download(path=s3_path, local_file=f)
    f.seek(0)
    rf = open(f.name, "r+t")
    for line in rf.readlines():
        print(line)
        object_type, db_name, object_name, object_data = line.split("\t")
        # print(object_type, db_name, object_name)
        if object_type == 'database':
            database_data = json.loads(object_data)
            print(f"Restoring database {object_name} json data => {database_data}")
            database_s3_location_target = None
            if update_table_s3_location and 'LocationUri' in database_data:
                database_s3_location = database_data.get('LocationUri')
                if database_s3_location is not None:
                    u = urlparse(database_s3_location)
                    print(f"Received table_s3_mapping {table_s3_mapping}")
                    print(f"Received type table_s3_mapping {type(table_s3_mapping)}")
                    if u.netloc in table_s3_mapping:
                        target_s3_location = table_s3_mapping[u.netloc]
                        u = u._replace(netloc=target_s3_location)
                        print(u)
                        database_s3_location_target = urlunparse(u)
                    else:
                        database_s3_location_target = database_s3_location
                    database_data['LocationUri'] = database_s3_location_target
            create_database(glue_client, database_data)
            database_count[db_name] += 1
        elif object_type == 'table':
            table_data = json.loads(object_data)
            print(f"Restoring table {object_name} json data => {table_data}")
            table_s3_location_target = None
            if update_table_s3_location:
                # A Glue table keeps its S3 location under StorageDescriptor.Location.
                storage_descriptor = table_data.get('StorageDescriptor', {})
                table_s3_location = storage_descriptor.get('Location')
                if table_s3_location is not None:
                    u = urlparse(table_s3_location)
                    if u.netloc in table_s3_mapping:
                        target_s3_location = table_s3_mapping[u.netloc]
                        u = u._replace(netloc=target_s3_location)
                        table_s3_location_target = urlunparse(u)
                    else:
                        table_s3_location_target = table_s3_location
                    storage_descriptor['Location'] = table_s3_location_target
            create_table(glue_client, db_name, table_data)
            table_count[db_name] += 1
    f.close()
    for db_name in database_count.keys():
        print(f"{db_name}=>table_count:{table_count[db_name]}")
    print(f"Restored database count => {len(database_count)} table count => {sum(table_count.values())}")
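# Illustrative only: target_s3_locations in the config is a Python dict literal mapping a source
# bucket name to a target bucket name; restore_data() swaps only the bucket (netloc) part of each
# location and keeps the key/prefix unchanged. The bucket names below are hypothetical.
#
#   table_s3_mapping = {'example-source-data-bucket': 'example-target-data-bucket'}
#   # 's3://example-source-data-bucket/sales/orders/'  ->  's3://example-target-data-bucket/sales/orders/'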
def get_tables(source_region, data_source, db_list):
    """Query the Athena information_schema in the given region for the tables of the requested databases."""
    session_region = boto3.Session(region_name=source_region)
    db_list_string = "','".join(db_list)
    athena_query = f"""SELECT table_schema, table_name FROM information_schema.tables where table_schema in ('{db_list_string}') and table_catalog = lower('{data_source}')"""
    print(f"Running query {athena_query}")
    df = wr.athena.read_sql_query(athena_query, database=information_schema_name, ctas_approach=False, boto3_session=session_region)
    return df


def extract_database(source_region, output_file_name, db_list):
    """Dump the selected Glue databases and their tables to a tab-separated file and upload it to S3."""
    print("Extracting database...")
    table_count = Counter()
    database_count = Counter()
    glue_client = get_client(source_region, 'glue')
    table_paginator = glue_client.get_paginator("get_tables")
    db_paginator = glue_client.get_paginator("get_databases")
    # Create the temp file once, outside the pagination loop, so databases from every page are kept.
    database_data_file = tempfile.TemporaryFile(mode='w+')
    for page in db_paginator.paginate():
        for db in page['DatabaseList']:
            if db_list == ['ALL_DATABASE'] or (db['Name'] in db_list):
                print(f"Database {db['Name']} matched with list of databases to be extracted")
                col_to_be_removed = ['CreateTime', 'CatalogId', 'VersionId']
                _db = [db.pop(key, '') for key in col_to_be_removed]
                database_data_file.write(f"database\t{db['Name']}\t\t{json.dumps(db)}\n")
                database_count[db['Name']] += 1
                for table_page in table_paginator.paginate(DatabaseName=db['Name']):
                    for table in table_page['TableList']:
                        print(f"Processing table {table['Name']}")
                        col_to_be_removed = ['CatalogId', 'DatabaseName', 'LastAccessTime', 'CreateTime', 'UpdateTime', 'CreatedBy', 'IsRegisteredWithLakeFormation', 'VersionId']
                        _table = [table.pop(key, '') for key in col_to_be_removed]
                        database_data_file.write(f"table\t{db['Name']}\t{table['Name']}\t{json.dumps(table)}\n")
                        table_count[db['Name']] += 1
    for db_name in table_count.keys():
        print(f"{db_name}=>table_count:{table_count[db_name]}")
    database_data_file.seek(0)
    with open(database_data_file.name, 'rb') as rf:
        wr.s3.upload(local_file=rf, path=output_file_name)
    print(f"Stored data in database_data_file.name {database_data_file.name}")
    print(f"Output_file_name {output_file_name}")
    print(f"Extracted database count => {len(database_count)} total table count => {sum(table_count.values())}")
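# Illustrative only: the extract file written by extract_database() and parsed by restore_data()
# holds one tab-separated record per catalog object (type, database, table name, JSON payload).
# The database and table names below are hypothetical.
#
#   database<TAB>sales_db<TAB><TAB>{"Name": "sales_db", "LocationUri": "s3://example-source-data-bucket/sales/"}
#   table<TAB>sales_db<TAB>orders<TAB>{"Name": "orders", "StorageDescriptor": {...}, "PartitionKeys": [...]}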
def compare_db_tables(config, data_source):
    """Compare the table listings of the source and destination regions and report matched,
    source-only and target-only tables."""
    source_region = config[data_source]['source_region']
    destination_region = config[data_source]['destination_region']
    output_file_name = config[data_source]['s3_data_path']
    db_list = ast.literal_eval(config[data_source]['database_list'])
    start_time = time.time()
    print(f"Starting processing at {time.asctime(time.localtime(time.time()))} with the following parameters ")
    print(f"datasource => {data_source}")
    print(f"source region name => {source_region}")
    print(f"database => {db_list}")
    print(f"output_file_name => {output_file_name}")
    print("=============================================================================")
    source_tables_df = get_tables(source_region, data_source, db_list)
    destination_tables_df = get_tables(destination_region, data_source, db_list)
    matched_df, source_only_df, target_only_df = compare_df(source_tables_df, destination_tables_df, df_index)
    print("=" * 50)
    if not matched_df.dropna().empty:
        print("Matched Tables")
        print(matched_df[df_index].to_string(index=False))
    else:
        print("No tables are matched")
    print("-" * 50)
    if not source_only_df.dropna().empty:
        print("Source only tables")
        print(source_only_df[df_index].to_string(index=False))
    else:
        print("No tables are found in the source")
    print("-" * 50)
    if not target_only_df.dropna().empty:
        print("Target only tables")
        print(target_only_df[df_index].to_string(index=False))
    else:
        print("No tables are found in the target region")
    print("-" * 50)
    return matched_df, source_only_df, target_only_df


def delete_target_tables(config, data_source):
    matched_df, source_only_df, target_only_df = compare_db_tables(config, data_source)
    tables_to_be_deleted = target_only_df
    # TODO: delete tables


def main():
    start_time = time.time()
    args = getResolvedOptions(sys.argv, ['CONFIG_BUCKET', 'CONFIG_FILE_KEY'])
    config_file_bucket = args['CONFIG_BUCKET']
    config_file_key = args['CONFIG_FILE_KEY']
    # If the key name starts with /, remove it.
    config_file_key = config_file_key.lstrip("/")
    print(f"Reading config file from s3://{config_file_bucket}/{config_file_key}")
    config = get_config(config_file_bucket, config_file_key)
    delete_target_catalog_objects = config.getboolean('Operation', 'delete_target_catalog_objects')
    sync_glue_catalog = config.getboolean('Operation', 'sync_glue_catalog')
    sync_lf_permissions = config.getboolean('Operation', 'sync_lf_permissions')
    update_table_s3_location = config.getboolean('Target_s3_update', 'update_table_s3_location')
    table_s3_mapping = ast.literal_eval(config.get('AwsDataCatalog', 'target_s3_locations'))
    list_datasource = ast.literal_eval(config.get('ListCatalog', 'list_datasource'))
    print(f"Received list of data sources {list_datasource}")
    source_lf_client = get_client(config['AwsDataCatalog']['source_region'], 'lakeformation')
    destination_lf_client = get_client(config['AwsDataCatalog']['destination_region'], 'lakeformation')
    glue_client = get_client(config['AwsDataCatalog']['destination_region'], 'glue')
    for data_source in list_datasource:
        source_region = config['AwsDataCatalog']['source_region']
        output_file_name = config[data_source]['s3_data_path']
        target_region = config['AwsDataCatalog']['destination_region']
        lf_storage_bucket = config['LakeFormationPermissions']['lf_storage_bucket']
        lf_storage_folder = config['LakeFormationPermissions']['lf_storage_file_folder']
        lf_storage_file_name = config['LakeFormationPermissions']['lf_storage_file_name']
        db_list = ast.literal_eval(config[data_source]['database_list'])
        if sync_glue_catalog:
            print(f"Starting processing at {time.asctime(time.localtime(time.time()))} with the following parameters ")
            print(f"datasource => {data_source}")
            print(f"source region name => {source_region}")
            print(f"database => {db_list}")
            print(f"output_file_name => {output_file_name}")
            print("=============================Starting Processing ================================================")
            extract_database(source_region, output_file_name, db_list)
            restore_data(config, data_source, glue_client, update_table_s3_location, table_s3_mapping)
        if delete_target_catalog_objects:
            delete_target_tables(config, data_source)
        if sync_lf_permissions:
            # Snapshot permissions from both regions, then re-grant the source permissions in the destination.
            permission_data = get_permissions(source_lf_client)
            store_permission_data(permission_data, source_region, lf_storage_bucket, lf_storage_folder, lf_storage_file_name)
            permission_data = get_permissions(destination_lf_client)
            store_permission_data(permission_data, target_region, lf_storage_bucket, lf_storage_folder, lf_storage_file_name)
            apply_table_permissions(f"{config['LakeFormationPermissions']['lf_storage_file_name']}", destination_lf_client, db_list, source_region, lf_storage_bucket, lf_storage_folder, lf_storage_file_name)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Processing finished in {int(execution_time)} secs")
    print("=" * 50)
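# Illustrative only: the job expects CONFIG_BUCKET and CONFIG_FILE_KEY as Glue job arguments
# (getResolvedOptions reads them from sys.argv as --CONFIG_BUCKET / --CONFIG_FILE_KEY).
# The job, bucket and key names below are hypothetical.
#
#   boto3.client('glue').start_job_run(
#       JobName='glue-catalog-sync-job',
#       Arguments={'--CONFIG_BUCKET': 'example-config-bucket',
#                  '--CONFIG_FILE_KEY': 'config/glue_config.conf'})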
main()
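# Illustrative only: the layout of the INI configuration file read by get_config(). The section
# and option names match what main() and compare_db_tables() read; every value shown (regions,
# buckets, paths, databases) is hypothetical.
#
#   [Operation]
#   sync_glue_catalog = True
#   sync_lf_permissions = True
#   delete_target_catalog_objects = False
#
#   [Target_s3_update]
#   update_table_s3_location = True
#
#   [ListCatalog]
#   list_datasource = ['AwsDataCatalog']
#
#   [AwsDataCatalog]
#   source_region = us-east-1
#   destination_region = us-west-2
#   target_s3_locations = {'example-source-data-bucket': 'example-target-data-bucket'}
#   s3_data_path = s3://example-backup-bucket/catalog/awsdatacatalog_extract.tsv
#   database_list = ['sales_db', 'marketing_db']
#
#   [LakeFormationPermissions]
#   lf_storage_bucket = example-backup-bucket
#   lf_storage_file_folder = lakeformation/permissions
#   lf_storage_file_name = lf_permissions.json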