# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
import tarfile
import boto3
import os
import json
import yaml

# first we need to retrieve the model pth file, for that let's consult the model package
model_package_arn = os.environ["MODEL_PACKAGE_ARN"]
deployment_bucket_name = os.environ['DEPLOYMENT_BUCKET_NAME']
region = os.environ["AWS_REGION"]

client_sm = boto3.client("sagemaker")

def update_component_config_in_json_file(filepath, version, component_name):
    with open(filepath+'gdk-config.json') as f:
        data = json.load(f)

        data['component'][component_name]['publish']['bucket'] = data['component'][component_name]['publish']['bucket'].replace('_BUCKET_NAME_', deployment_bucket_name)
        data['component'][component_name]['publish']['region'] = data['component'][component_name]['publish']['region'].replace('_REGION_', region)
        data['component'][component_name]['version'] = data['component'][component_name]['version'].replace('_COMPONENT_VERSION_', version)

        with open(filepath+'gdk-config.json', 'w') as f:
            json.dump(data, f)

def update_component_recipe_yaml_file(filepath, version, component_name, model_name):
    with open(filepath+"recipe.yaml", 'r') as stream:
        try:
            loaded = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    # Modify the fields from the dict
    if "detector" in filepath:
        loaded['ComponentConfiguration']['DefaultConfiguration']['model_name'] = model_name
        loaded['ComponentConfiguration']['DefaultConfiguration']['model_version'] = version
        
    loaded['ComponentName'] = component_name
    loaded['ComponentVersion'] = version
    loaded['ComponentPublisher'] = "Amazon.com"
    loaded['Manifests'][0]['Artifacts'][0]['URI'] = "s3://"+deployment_bucket_name+"/"+component_name+"/"+version+"/"+component_name+".zip"

    # Save it again
    with open(filepath+"recipe.yaml", 'w') as stream:
        try:
            yaml.dump(loaded, stream, default_flow_style=False)
        except yaml.YAMLError as exc:
            print(exc)

# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model_package
response = client_sm.describe_model_package(ModelPackageName=model_package_arn)

s3_model_location = response['InferenceSpecification']['Containers'][0]['ModelDataUrl']
model_package_version = response['ModelPackageVersion']

# let's pull the compressed model data from the bucket
client_s3 = boto3.client("s3")

bucket, key = s3_model_location.split('/',2)[-1].split('/',1)

try:
    client_s3.download_file(bucket, key, 'model.tar.gz')
except Exception as e:
    print(e)

# let's unzip the model package
file = tarfile.open('model.tar.gz')
file.extractall('.')
file.close()

# now load the model
pytorch_model = torch.load('model.pth',  map_location='cpu')

pytorch_model.eval() 
n_features=6
x = torch.rand(1,n_features,10,10).float()

input_names = [ "input"]
output_names = [ "output" ]

output_onnx_model_name = 'windturbine'
output_onnx_model = './aws.samples.windturbine.model/'+output_onnx_model_name+'.onnx'

torch.onnx.export(pytorch_model,
                 x,
                 output_onnx_model,
                 verbose=True,
                 input_names=input_names,
                 output_names=output_names,
                 export_params=True,
                 )

# Update the recipe/config for each component
component_version = '1.0.'+str(model_package_version)
update_component_config_in_json_file('./aws.samples.windturbine.model/', component_version, 'aws.samples.windturbine.model')
update_component_config_in_json_file('./aws.samples.windturbine.detector.venv/', component_version, 'aws.samples.windturbine.detector.venv')
update_component_config_in_json_file('./aws.samples.windturbine.detector/', component_version, 'aws.samples.windturbine.detector')

update_component_recipe_yaml_file('./aws.samples.windturbine.model/', component_version, 'aws.samples.windturbine.model', output_onnx_model_name)
update_component_recipe_yaml_file('./aws.samples.windturbine.detector/', component_version, 'aws.samples.windturbine.detector', output_onnx_model_name)