from train import train, parse_args

import sys
import os
import boto3
import json

# Resolve this file's directory so config.json is located correctly
# regardless of the current working directory.
dirname = os.path.dirname(os.path.abspath(__file__))

# Load the shared configuration (e.g. the public S3 bucket name) once
# at import time; CONFIG is read by download_from_s3 below.
with open(os.path.join(dirname, "config.json")) as config_file:
    CONFIG = json.load(config_file)
    
def download_from_s3(data_dir='/tmp/data', train=True):
    """Download the gzipped MNIST files from the public S3 bucket.

    Args:
        data_dir (str): local directory to save the files into
        train (bool): if True fetch the training split, otherwise the
            10k test split

    Returns:
        None. Files are written into ``data_dir``; files already
        present locally are not re-downloaded.
    """
    # exist_ok avoids a check-then-create race and an error if the
    # directory was created by an earlier call.
    os.makedirs(data_dir, exist_ok=True)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # Download each object unless a local copy already exists.
    s3 = boto3.client('s3')
    bucket = CONFIG["public_bucket"]
    for obj in (images_file, labels_file):
        # S3 keys always use "/" separators; os.path.join would emit
        # backslashes on Windows and produce a nonexistent key.
        key = "datasets/image/MNIST/" + obj
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)

class Env:
    """Simulate the SageMaker training container environment.

    Instantiating this class injects the SM_* environment variables
    that the real container would provide, so the training entry
    point can run on a local machine.
    """

    def __init__(self):
        # Values mirror a single-host, CPU-only container with the
        # data channels pointing at the local download directory.
        simulated = {
            "SM_MODEL_DIR": "/tmp/tf/model",
            "SM_CHANNEL_TRAINING": "/tmp/data",
            "SM_CHANNEL_TESTING": "/tmp/data",
            "SM_HOSTS": '["algo-1"]',
            "SM_CURRENT_HOST": "algo-1",
            "SM_NUM_GPUS": "0",
        }
        os.environ.update(simulated)
    
def _main():
    """Set up the simulated container env, fetch both MNIST splits,
    then run training with the parsed command-line arguments."""
    Env()
    download_from_s3()
    download_from_s3(train=False)
    train(parse_args())


if __name__ == '__main__':
    _main()