# Data & Model Preparation
This notebook prepares the dataset and model for the model evaluation lab. This step is optional if you have kept your artifacts from the previous modules.

## Import modules and initialize parameters for this notebook

In [None]:
import sagemaker
from sagemaker import get_execution_role
import glob
import random
import shutil
import os

# The session supplies the account id, region and default bucket used below.
sess = sagemaker.Session()
role = get_execution_role()  # role handed to the processing and training jobs

region = sess.boto_region_name
account = sess.account_id()
bucket = sess.default_bucket() # or use your own custom bucket name

# Shared prefix for the S3 keys and job names created by this notebook.
prefix = 'BIRD-Sagemaker-Deployment'

## Dataset
The dataset we are using is the [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset, which contains 11,788 images across 200 bird species (the original technical report can be found on the dataset page). Each species comes with around 60 images, with a typical size of about 350 pixels by 500 pixels. Bounding boxes are provided, as are annotations of bird parts. A recommended train/test split is given, but image size data is not.

Run the cell below to download the full dataset, or download it manually [here](https://course.fast.ai/datasets). Note that the file size is around 1.2 GB, so it can take a while to download. If you plan to complete the entire workshop, please keep the file to avoid re-downloading and re-processing the data.

In [None]:
# Fetch the CUB_200_2011 archive, unpack it in the current directory
# (later cells read from ./CUB_200_2011), then delete the archive itself.
!wget 'https://s3.amazonaws.com/fast-ai-imageclas/CUB_200_2011.tgz'
!tar xopf CUB_200_2011.tgz
!rm CUB_200_2011.tgz

Generate test samples for this lab

In [None]:
def copy_random_samples(image_paths, dst_dir, num_samples):
    """Copy up to ``num_samples`` distinct, randomly chosen files into ``dst_dir``.

    Files are drawn without replacement, so the same image is never copied
    twice. When fewer than ``num_samples`` paths are available, all of them
    are copied.

    Args:
        image_paths: list of file paths to choose from (may be empty).
        dst_dir: existing directory the chosen files are copied into.
        num_samples: maximum number of files to copy.

    Returns:
        The list of source paths that were copied.
    """
    picked = random.sample(image_paths, min(num_samples, len(image_paths)))
    for path in picked:
        shutil.copy(path, dst_dir)
    return picked


image_folder = 'CUB_200_2011/images'
dst = 'build/image_classification/images'
num_test_samples = 10

if not os.path.exists(dst):
    os.makedirs(dst)
    print("make new directory.....")

# Collect every file exactly one directory level below image_folder.
img_array = glob.glob(f'{image_folder}/*/*')

if not img_array:
    # Previously this case crashed inside random.randint with an unhelpful message.
    print(f"No images found under {image_folder}; run the download cell first.")

copied_samples = copy_random_samples(img_array, dst, num_test_samples)

Copy data to S3

In [None]:
s3_raw_data = f's3://{bucket}/{prefix}/full/data'
!aws s3 cp --recursive ./CUB_200_2011 $s3_raw_data

In [None]:
!rm -rf ./CUB_200_2011
!rm -rf attributes.txt

In [None]:
from sagemaker.sklearn.processing import SKLearnProcessor

from sagemaker.processing import (
    ProcessingInput,
    ProcessingOutput,
)
import time 

# Whole-second Unix timestamp as a string.
timestamp = str(int(time.time()))
# The original, misspelled name is kept as an alias so any cell that still
# refers to it keeps working.
timpstamp = timestamp

# SKlearnProcessor for preprocessing
# S3 location under which the processing job writes its outputs.
output_prefix = f'{prefix}/outputs'
output_s3_uri = f's3://{bucket}/{output_prefix}'

# Forwarded to the preprocessing script as --classes and --input-data.
class_selection = '13, 17, 35, 36, 47, 68, 73, 87'
input_annotation = 'classes.txt'

processing_instance_type = "ml.m5.xlarge"
processing_instance_count = 1

sklearn_processor = SKLearnProcessor(
    base_job_name=f"{prefix}-preprocess",  # choose any name
    framework_version='0.20.0',
    role=role,
    instance_type=processing_instance_type,
    instance_count=processing_instance_count,
)

In [None]:
# The raw dataset uploaded earlier, made available to the job under /opt/ml/processing/input.
raw_input = ProcessingInput(source=s3_raw_data, destination="/opt/ml/processing/input")

# One output per split: a directory under /opt/ml/processing/output paired with
# the same-named location under output_s3_uri.
processing_outputs = [
    ProcessingOutput(
        source=f"/opt/ml/processing/output/{split}",
        destination=f"{output_s3_uri}/{split}",
    )
    for split in ("train", "valid", "test", "manifest")
]

sklearn_processor.run(
    code='../02_preprocessing/preprocessing.py',
    arguments=["--classes", class_selection, "--input-data", input_annotation],
    inputs=[raw_input],
    outputs=processing_outputs,
)

This is where your images and annotation files are located.  You will need these for this module.

In [None]:
# Destinations configured for the processing job's test and manifest outputs.
test_data_uri = f"{output_s3_uri}/test"
manifest_uri = f"{output_s3_uri}/manifest"

print(f"Test dataset located here: {test_data_uri} ===========")

print(f"Test annotation file is located here: {manifest_uri} ===========")

In [None]:
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import TrainingStep
from sagemaker.tensorflow import TensorFlow

TF_FRAMEWORK_VERSION = '2.4.1'

# Instance type and count for the training job.
training_instance_type = 'ml.c5.4xlarge'
training_instance_count = 1

# Hyperparameters handed to the estimator.
hyperparameters = {
    'initial_epochs': 5,
    'batch_size': 8,
    'fine_tuning_epochs': 20,
    'dropout': 0.4,
    'data_dir': '/opt/ml/input/data',
}

# Metric names and the regexes used to extract their values, handed to the estimator.
metric_definitions = [
    {'Name': 'loss', 'Regex': 'loss: ([0-9\\.]+)'},
    {'Name': 'acc', 'Regex': 'accuracy: ([0-9\\.]+)'},
    {'Name': 'val_loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
    {'Name': 'val_acc', 'Regex': 'val_accuracy: ([0-9\\.]+)'},
]

DISTRIBUTION_MODE = 'FullyReplicated'
distribution = {'parameter_server': {'enabled': False}}

# One TrainingInput per split written by the processing job.
train_in, val_in, test_in = (
    TrainingInput(s3_data=output_s3_uri + '/' + split, distribution=DISTRIBUTION_MODE)
    for split in ('train', 'valid', 'test')
)

# Channel name -> input, as passed to estimator.fit().
inputs = {'train': train_in, 'test': test_in, 'validation': val_in}

In [None]:
# S3 location handed to the estimator as its output path.
model_path = f"s3://{bucket}/{prefix}"

estimator = TensorFlow(
    # Training script and the directory it is taken from
    entry_point='train-mobilenet.py',
    source_dir='../05_model_evaluation_and_model_explainability/05a_model_evaluation/code',
    # Framework / runtime
    framework_version=TF_FRAMEWORK_VERSION,
    py_version='py37',
    script_mode=True,
    # Infrastructure
    role=role,
    instance_type=training_instance_type,
    instance_count=training_instance_count,
    distribution=distribution,
    # Job configuration
    base_job_name=prefix,
    output_path=model_path,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
)

In [None]:
# Start the training job with the train / test / validation channels defined above.
estimator.fit(inputs=inputs)

In [None]:
# Name of the most recent training job started from this estimator.
training_job_name = estimator.latest_training_job.name

artifact_location = f"{model_path}/{training_job_name}/output"
print(f"model artifacts file is uploaded here: {artifact_location} ========")

