# Module 1. Data Augmentation
---

This notebook demonstrates image augmentation, a technique that increases the diversity of the training set by applying transforms such as affine transforms (rotation, shift, etc.) and blur. It uses the `albumentations` library.

- Very similar to PyTorch's torchvision (you can learn it in 5-10 minutes) 
- Documentation: https://albumentations.readthedocs.io/en/latest/

This hands-on lab takes about **10 minutes** to complete.

<br>

# 1. Preparation
---

## Install and upgrade packages

If you are running this on a new Jupyter notebook instance, set `install_needed = True` in the code cell below and run the cell. After the kernel restarts, set `install_needed = False` again. You only need to do this once.

<div class="alert alert-info"><h4>Note</h4><p>
We pin torch to a specific version so that model training, TorchScript conversion, and SageMaker Neo compilation all use the same version. When compiling models, keep the versions matched whenever possible.</p></div>
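
As a quick sanity check after the kernel restarts, the installed versions can be compared against the pins. The snippet below is a minimal sketch, not part of the original notebook: `PINNED` and `find_version_mismatches` are hypothetical names, and the pins are copied from the install cell below.

```python
from importlib import metadata

# Pins copied from the install cell; adjust them if you change that cell.
PINNED = {"torch": "1.8.1", "torchvision": "0.9.1"}


def find_version_mismatches(pinned):
    """Return {package: (expected, installed_or_None)} for every unsatisfied pin."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Local build tags such as '1.8.1+cu111' still count as version 1.8.1
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


# An empty dict means every pinned package matches.
print(find_version_mismatches(PINNED))
```
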

In [None]:
%store -z
%load_ext autoreload
%autoreload 2
%matplotlib inline
import sys
import logging
import IPython
import importlib

install_needed = False
# install_needed = True  # uncomment on the first run only

if install_needed:
    print("===> Installing deps and restarting kernel. Please change 'install_needed = False' and run this code cell again.")
    is_torch = importlib.util.find_spec("torch") 
    found = is_torch is not None
    !{sys.executable} -m pip install -U torch==1.8.1 torchvision==0.9.1 fastai==2.5.3 "opencv-python-headless<4.3"
    !{sys.executable} -m pip install -qU glob2 smdebug albumentations
    IPython.Application.instance().kernel.do_shutdown(True)
    
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    stream=sys.stdout,
)

logger = logging.getLogger()    

In [None]:
import os
import glob2
import cv2
import numpy as np
import albumentations as A
from matplotlib import pyplot as plt

raw_dir = 'raw'
dataset_dir = 'bioplus'
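# Keep only class sub-directories; skip hidden entries such as .ipynb_checkpoints
classes = sorted(
    d for d in os.listdir(raw_dir)
    if os.path.isdir(os.path.join(raw_dir, d)) and not d.startswith('.')
)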
num_classes = len(classes)
train_size = 0.8
num_augmentations = 5
!rm -rf {dataset_dir}
print(classes)

<br>

# 2. Data Augmentation
---


In [None]:
def _get_transforms_augmentation(cropsize_dim, resize_dim=500):
    """
    Declare an augmentation pipeline
    """
    transforms = A.Compose([
        A.CenterCrop(cropsize_dim, cropsize_dim),
        A.Resize(resize_dim, resize_dim),
        A.GaussNoise(p=0.4),
        A.RandomBrightnessContrast(p=0.2),
        A.OneOf([
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.VerticalFlip(p=0.5)           
        ], p=0.2),   
        A.OneOf([
            A.MotionBlur(p=.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.3),  
        A.OneOf([
            A.CLAHE(clip_limit=2),
            A.Sharpen(),
            A.HueSaturationValue(p=0.3),           
        ], p=0.3),
        A.OneOf([
            A.Rotate(10, p=0.6),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=10, p=0.4),
        ], p=0.3),
    ], p=1.0)
    return transforms


def _make_augmented_images(f, write_path, phase, num_augmentations=10):
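    """
    Write `num_augmentations` augmented copies of one raw image to `write_path`.
    Useful when you do not have enough raw data.
    """
    image = cv2.imread(f)
    if image is None:
        # cv2.imread returns None instead of raising when a file cannot be read
        logger.warning(f'[{phase}] Skipping unreadable image: {f}')
        return

    h, w, c = image.shape
    # Side length of the largest centered square that fits in the image
    cropsize_dim = int(min(h, w))

    filename = os.path.basename(f)
    # splitext keeps dots inside the name (e.g. 'img.v2.jpg' -> 'img.v2')
    filename_noext = os.path.splitext(filename)[0]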
    logger.info(f'[{phase}] Augmenting image: {filename}')
    
    for k in range(num_augmentations):
        transforms = _get_transforms_augmentation(cropsize_dim=cropsize_dim)
        transformed = transforms(image=image)
        transformed_image = transformed["image"]
        cv2.imwrite(os.path.join(write_path, f'{filename_noext}_aug_{k:05d}.jpg'), transformed_image)  
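
The probabilities in the pipeline above interact in a way that is easy to misread. To our understanding of `A.OneOf`, the block fires with its own `p`, then picks exactly one child with a weight proportional to that child's `p` and applies it unconditionally. The following sketch computes the resulting per-image probabilities in plain Python; `one_of_effective_probs` is a hypothetical helper, not an `albumentations` function.

```python
def one_of_effective_probs(block_p, child_ps):
    """Chance that each child of a OneOf block is applied to a given image."""
    total = sum(child_ps.values())
    # One child is chosen with weight p_child / sum(p_children),
    # and the whole block only fires with probability block_p.
    return {name: block_p * p / total for name, p in child_ps.items()}


# Values taken from the blur block of the pipeline above
blur_block = one_of_effective_probs(
    0.3, {"MotionBlur": 0.2, "MedianBlur": 0.1, "Blur": 0.1}
)
for name, p in blur_block.items():
    print(f"{name}: {p:.3f}")
```

Under this assumption, `MotionBlur` is applied to roughly 15% of the images and each of the other two blurs to roughly 7.5%, so the child `p` values act as relative weights rather than as independent probabilities.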

In [None]:
for c in classes:

    img_raw_path = os.path.join(raw_dir, c)
    img_train_path = os.path.join(dataset_dir, 'train', c)
    img_valid_path = os.path.join(dataset_dir, 'valid', c)

    os.makedirs(img_train_path, exist_ok=True)
    os.makedirs(img_valid_path, exist_ok=True)

    # Sort so the train/valid split is reproducible across runs
    files = sorted(glob2.glob(f"{img_raw_path}/*.jpg"))
    num_files = len(files)
    num_train_files = int(num_files * train_size)

    logger.info('-' * 70)   
    logger.info(f'Augmenting class: {c}')
    logger.info(f'img_train_path: {img_train_path}')
    logger.info(f'img_valid_path: {img_valid_path}')
    logger.info(f'num_raw_files={num_files}, num_raw_train_files={num_train_files}')
    logger.info('-' * 70)

    # training images
    for f in files[:num_train_files]:
        _make_augmented_images(f, img_train_path, 'train', num_augmentations)

    # validation images
    for f in files[num_train_files:]:
        _make_augmented_images(f, img_valid_path, 'valid', num_augmentations)
    
    logger.info('')
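
The loop above assigns the first 80% of each class's files to training and the rest to validation. If file names correlate with capture conditions, a seeded shuffle before slicing may give a more representative split. The sketch below shows one way to do this; `split_train_valid` and its `seed` parameter are hypothetical and not part of the original notebook, and the example paths are made up.

```python
import random


def split_train_valid(files, train_size=0.8, seed=42):
    """Shuffle deterministically, then split into (train, valid) lists."""
    shuffled = sorted(files)               # fixed starting order
    random.Random(seed).shuffle(shuffled)  # same seed -> same split
    num_train = int(len(shuffled) * train_size)
    return shuffled[:num_train], shuffled[num_train:]


example_files = [f"raw/class_a/img_{i:03d}.jpg" for i in range(10)]
train_files, valid_files = split_train_valid(example_files)
print(len(train_files), len(valid_files))  # 8 2
```

Because the split happens on raw files before augmentation, all augmented copies of one raw image stay on the same side of the split.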

## Copy data to S3

Copy the dataset to S3. The images are uploaded as individual files without conversion; for more efficient training, consider converting them to a format such as TFRecord or RecordIO later.

In [None]:
import sagemaker
bucket = sagemaker.Session().default_bucket()
s3_path = f's3://{bucket}/{dataset_dir}'

In [None]:
%%time
!aws s3 cp {dataset_dir} s3://{bucket}/{dataset_dir} --recursive --quiet
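
Before running the copy, it can help to preview which object keys a recursive copy of the directory would produce. This is a minimal sketch using only the standard library; `preview_s3_keys` is a hypothetical helper, the demo file and `class_a` are made up, and it does not call AWS at all.

```python
import os
import tempfile


def preview_s3_keys(local_dir, prefix):
    """List (local_path, s3_key) pairs for every file under local_dir."""
    pairs = []
    for root, _dirs, names in os.walk(local_dir):
        for name in names:
            local_path = os.path.join(root, name)
            rel = os.path.relpath(local_path, local_dir)
            # S3 keys always use forward slashes, whatever the local OS uses
            key = "/".join([prefix] + rel.split(os.sep))
            pairs.append((local_path, key))
    return sorted(pairs)


# Demo on a throwaway directory so the sketch runs anywhere
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = os.path.join(tmp, "bioplus")
    os.makedirs(os.path.join(demo_dir, "train", "class_a"))
    open(os.path.join(demo_dir, "train", "class_a", "img_aug_00000.jpg"), "w").close()
    demo_keys = [key for _path, key in preview_s3_keys(demo_dir, "bioplus")]

print(demo_keys)  # ['bioplus/train/class_a/img_aug_00000.jpg']
```
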

## Store class map as JSON

Store the class dictionary as a JSON file. This file will be useful later for model inference.

In [None]:
import src.train_utils as train_utils
classes, classes_dict = train_utils.get_classes(f'./{dataset_dir}/train') 
train_utils.save_classes_dict(classes_dict, 'classes_dict.json')
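
The contents of `src/train_utils.py` are not shown in this notebook, so the exact format of the class dictionary is not confirmed here. As an illustration only, the sketch below assumes a mapping from class name to integer index built from the sorted sub-directory names; `build_classes_dict` is a hypothetical stand-in, and the real helpers may differ (for example, they may map index to name).

```python
import json
import os
import tempfile


def build_classes_dict(train_dir):
    """Map each class sub-directory name to an integer index (sorted order)."""
    names = sorted(
        d for d in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, d)) and not d.startswith(".")
    )
    return {name: idx for idx, name in enumerate(names)}


# Demo on a throwaway directory so the sketch runs anywhere
with tempfile.TemporaryDirectory() as tmp:
    for name in ["class_b", "class_a"]:
        os.makedirs(os.path.join(tmp, "train", name))
    demo_dict = build_classes_dict(os.path.join(tmp, "train"))

    # Round-trip through JSON, as an inference script would read it back
    json_path = os.path.join(tmp, "classes_dict.json")
    with open(json_path, "w") as fp:
        json.dump(demo_dict, fp, indent=2)
    with open(json_path) as fp:
        restored = json.load(fp)

print(restored)  # {'class_a': 0, 'class_b': 1}
```
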

In [None]:
%store bucket dataset_dir raw_dir classes num_classes

<br>

# Next Step

With the training data ready, the next step is to develop and train the model. If you are new to PyTorch, proceed to `2_local_training.ipynb` first. If you are already familiar with PyTorch, you can skip it and go directly to `3_sm_training.ipynb`.