{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Module 1. Data Augmentation\n", "---\n", "\n", "This notebook demonstrates image augmentation: increasing the diversity of the training set by applying transforms such as affine transforms (rotation, shift, etc.) and blur, using the `albumentations` library.\n", "\n", "- The API is very similar to PyTorch's torchvision transforms (you can learn it in 5-10 minutes)\n", "- Documentation: https://albumentations.readthedocs.io/en/latest/\n", "\n", "This hands-on exercise takes about **10 minutes**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# 1. Preparation\n", "---\n", "\n", "## Install and upgrade packages\n", "\n", "On a newly created Jupyter notebook instance, set `install_needed = True` in the code cell below and run it. The cell installs the packages and restarts the kernel; once the kernel has restarted, set `install_needed = False` and run the cell again. You only need to do this once.\n", "\n", "**Note**\n", "
\n", "We pin torch to a specific version so that model training, TorchScript conversion, and SageMaker Neo compilation all use the same version. When compiling models, keep these versions matched whenever possible." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%store -z\n", "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import sys\n", "import logging\n", "import IPython\n", "\n", "# Set to True on a new notebook instance; set back to False after the kernel restarts.\n", "install_needed = False\n", "\n", "if install_needed:\n", "    print(\"===> Installing deps and restarting kernel. Please change 'install_needed = False' and run this code cell again.\")\n", "    !{sys.executable} -m pip install -U torch==1.8.1 torchvision==0.9.1 fastai==2.5.3 \"opencv-python-headless<4.3\"\n", "    !{sys.executable} -m pip install -qU glob2 smdebug albumentations\n", "    # Restart the kernel so the newly installed packages are picked up.\n", "    IPython.Application.instance().kernel.do_shutdown(True)\n", "\n", "logging.basicConfig(\n", "    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',\n", "    level=logging.INFO,\n", "    datefmt='%Y-%m-%d %H:%M:%S',\n", "    stream=sys.stdout,\n", ")\n", "\n", "logger = logging.getLogger()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import glob2\n", "import cv2\n", "import numpy as np\n", "import albumentations as A\n", "from matplotlib import pyplot as plt\n", "\n", "raw_dir = 'raw'          # input images, one subdirectory per class\n", "dataset_dir = 'bioplus'  # output root for the train/valid splits\n", "# Each subdirectory of raw_dir is one class; skip hidden entries such as .ipynb_checkpoints.\n", "classes = sorted(d for d in os.listdir(raw_dir) if os.path.isdir(os.path.join(raw_dir, d)) and not d.startswith('.'))\n", "num_classes = len(classes)\n", "train_size = 0.8         # fraction of each class's raw images used for training\n", "num_augmentations = 5    # augmented copies written per raw image\n", "!rm -rf {dataset_dir}\n", "print(classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# 2. Data Augmentation\n", "---\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _get_transforms_augmentation(cropsize_dim, resize_dim=500):\n", "    \"\"\"\n", "    Declare an augmentation pipeline.\n", "    \"\"\"\n", "    transforms = A.Compose([\n", "        # Crop the largest centered square, then resize it to resize_dim x resize_dim.\n", "        A.CenterCrop(cropsize_dim, cropsize_dim),\n", "        A.Resize(resize_dim, resize_dim),\n", "        A.GaussNoise(p=0.4),\n", "        A.RandomBrightnessContrast(p=0.2),\n", "        # OneOf applies at most one of its children; the children's p values act as relative weights.\n", "        A.OneOf([\n", "            A.HorizontalFlip(p=0.5),\n", "            A.RandomRotate90(p=0.5),\n", "            A.VerticalFlip(p=0.5)\n", "        ], p=0.2),\n", "        A.OneOf([\n", "            A.MotionBlur(p=.2),\n", "            A.MedianBlur(blur_limit=3, p=0.1),\n", "            A.Blur(blur_limit=3, p=0.1),\n", "        ], p=0.3),\n", "        A.OneOf([\n", "            A.CLAHE(clip_limit=2),\n", "            A.Sharpen(),\n", "            A.HueSaturationValue(p=0.3),\n", "        ], p=0.3),\n", "        A.OneOf([\n", "            A.Rotate(10, p=0.6),\n", "            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=10, p=0.4),\n", "        ], p=0.3),\n", "    ], p=1.0)\n", "    return transforms\n", "\n", "\n", "def _make_augmented_images(f, write_path, phase, num_augmentations=10):\n", "    \"\"\"\n", "    Artificially augment raw image data.\n", "    If you do not have enough raw data, you can use this to enlarge the dataset.\n", "    \"\"\"\n", "    image = cv2.imread(f)\n", "    if image is None:\n", "        # cv2.imread returns None instead of raising when a file cannot be read.\n", "        logger.warning(f'[{phase}] Skipping unreadable image: {f}')\n", "        return\n", "\n", "    # The crop size is the shorter side, so the center crop is the largest square that fits.\n", "    h, w = image.shape[:2]\n", "    cropsize_dim = min(h, w)\n", "\n", "    filename = os.path.basename(f)\n", "    filename_noext = os.path.splitext(filename)[0]\n", "    logger.info(f'[{phase}] Augmenting image: {filename}')\n", "\n", "    # The pipeline depends only on the image size, so build it once per image.\n", "    transforms = _get_transforms_augmentation(cropsize_dim=cropsize_dim)\n", "    for k in range(num_augmentations):\n", "        transformed = transforms(image=image)\n", "        transformed_image = transformed[\"image\"]\n", "        cv2.imwrite(os.path.join(write_path, f'{filename_noext}_aug_{k:05d}.jpg'), transformed_image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for c in classes:\n", "\n", "    img_raw_path = os.path.join(raw_dir, c)\n", "    img_train_path = os.path.join(dataset_dir, 'train', c)\n", "    img_valid_path = os.path.join(dataset_dir, 'valid', c)\n", "\n", "    os.makedirs(img_train_path, exist_ok=True)\n", "    os.makedirs(img_valid_path, exist_ok=True)\n", "\n", "    # Sort so that the train/valid split is the same on every run.\n", "    files = sorted(glob2.glob(f\"{img_raw_path}/*.jpg\"))\n", "    num_files = len(files)\n", "    num_train_files = int(num_files * train_size)\n", "\n", "    logger.info('-' * 70)\n", "    logger.info(f'Augmenting class: {c}')\n", "    logger.info(f'img_train_path: {img_train_path}')\n", "    logger.info(f'img_valid_path: {img_valid_path}')\n", "    logger.info(f'num_raw_files={num_files}, num_raw_train_files={num_train_files}')\n", "    logger.info('-' * 70)\n", "\n", "    # training images\n", "    for f in files[:num_train_files]:\n", "        _make_augmented_images(f, img_train_path, 'train', num_augmentations)\n", "\n", "    # validation images\n", "    for f in files[num_train_files:]:\n", "        _make_augmented_images(f, img_valid_path, 'valid', num_augmentations)\n", "\n", "    logger.info('')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Copy data to S3\n", "\n", "Copy the data to S3. 
The images are copied as they are here; for more efficient training, consider converting them to a record format such as TFRecord or RecordIO." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "bucket = sagemaker.Session().default_bucket()\n", "s3_path = f's3://{bucket}/{dataset_dir}'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "!aws s3 cp {dataset_dir} {s3_path} --recursive --quiet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Store class map as JSON\n", "\n", "Store the class dictionary as JSON. This file will be useful later for model inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import src.train_utils as train_utils\n", "classes, classes_dict = train_utils.get_classes(f'./{dataset_dir}/train')\n", "train_utils.save_classes_dict(classes_dict, 'classes_dict.json')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%store bucket dataset_dir raw_dir classes num_classes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
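The `src.train_utils` helpers that store the class map are not shown in this notebook. As a rough sketch of what such helpers might do, assuming the class map is simply a dictionary from class name to integer index built from the subdirectory names of the training folder, the hypothetical functions below build, write, and read such a file. The function names, the sorted ordering, and the JSON layout are all assumptions, not the actual implementation of `get_classes` or `save_classes_dict`.

```python
import json
import os


def build_classes_dict(train_dir):
    # Hypothetical helper: one class per visible subdirectory, indexed in sorted order.
    classes = sorted(
        d for d in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, d)) and not d.startswith('.')
    )
    return classes, {name: idx for idx, name in enumerate(classes)}


def write_classes_dict(classes_dict, path):
    # Hypothetical helper: persist the name-to-index map as JSON.
    with open(path, 'w') as fp:
        json.dump(classes_dict, fp, indent=2)


def load_index_to_name(path):
    # Invert the stored map so a predicted index can be turned back into a class name.
    with open(path) as fp:
        classes_dict = json.load(fp)
    return {idx: name for name, idx in classes_dict.items()}
```

Under these assumptions, `load_index_to_name('classes_dict.json')` would give the lookup needed to decode a predicted class index at inference time.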
\n", "\n", "# Next Step\n", "\n", "With the training data ready, it is now time to develop and train the model. If you are unfamiliar with PyTorch, please proceed to `2_local_training.ipynb` first. If you are already familiar with PyTorch, you can skip `2_local_training.ipynb` and proceed directly to `3_sm_training.ipynb`." ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 4 }