{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SageMakerCV PyTorch Tutorial\n",
"\n",
"SageMakerCV is a collection of computer vision tools developed to take full advantage of Amazon SageMaker by providing state of the art model accuracy, training speed, and training cost reductions. SageMakerCV is based on the lessons we learned from developing the record breaking computer vision models we announced at Re:Invent in 2019 and 2020, along with talking to our customers and understanding the challenges they faced in training their own computer vision models.\n",
"\n",
"The tutorial in this notebook walks through using SageMakerCV to train Mask RCNN on the COCO dataset. The only prerequisite is to setup SageMaker studio, the instructions for which can be found in [Onboard to Amazon SageMaker Studio Using Quick Start](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html). Everything else, from getting the COCO data to launching a distributed training cluster, is included here.\n",
"\n",
"## Setup and Roadmap\n",
"\n",
"Before diving into the tutorial itself, let's take a minute to discuss the various tools we'll be using.\n",
"\n",
"#### SageMaker Studio\n",
"[SageMaker Studio](https://aws.amazon.com/sagemaker/studio/) is a machine learning focused IDE where you can interactively develop models and launch SageMaker training jobs all in one place. SageMaker Studio provides a Jupyter Lab like environment, but with a number of enhancements. We'll just scratch the surface here. See the [SageMaker Studio Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html) for more details.\n",
"\n",
"For our purposes, the biggest difference from regular Jupyter Lab is that SageMaker Studio allows you to change your compute resources as needed, by connecting notebooks to Docker containers on different ML instances. This is a little confusing to just describe, so let's walk through an example.\n",
"\n",
"Once you've completed the setup on [Onboard to Amazon SageMaker Studio Using Quick Start](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html), go to the [SageMaker Console](https://us-west-2.console.aws.amazon.com/sagemaker) and click `Open SageMaker Studio` near the top right of the page.\n",
"\n",
"
\n",
"\n",
"If you haven't yet created a user, do so via the link at the top left of the page. Give it any name you like. For execution role, you can either use an existing SageMaker role, or create a new one. If you're unsure, create a new role. On the `Create IAM Role` window, make sure to select `Any S3 Bucket`. \n",
"\n",
"
\n",
"\n",
"Back on the SageMaker Studio page, select `Open Studio` next to the user you just created.\n",
"\n",
"
\n",
"\n",
"This will take a couple minutes to start up the first time. Once it starts, you'll have a Jupyter Lab like interface running on a small instance with an attached EBS volume. Let's start by taking a look at the `Launcher` tab.\n",
"\n",
"
\n",
"\n",
"If you don't see the `Launcher`, you can bring one up by clicking the `+` on the menu bar in the upper left corner.\n",
"\n",
"
\n",
"\n",
"The `Launcher` gives you access to all kinds of tools. This is where you can create new notebooks, text files, or get a terminal for your instance. Try the `System Terminal`. This gives you a new terminal tab for your Studio instance. It's useful for things like downloading data or cloning github repos into studio. For example, you can run `aws s3 ls` to browse your current S3 buckets. Go ahead and clone this repo onto Studio with \n",
"\n",
"`git clone https://github.com/aws-samples/amazon-sagemaker-cv`\n",
"\n",
"Let's look at the launcher one more time. Bring another one up with the `+`. Notice you have an option for `Select a SageMaker image` above the button to launch a notebook. This allows you to select a Docker image that will launch on a new instance. The notebook you create will be attached to that new instance, along with the EBS volume on your Studio instance. Let's try it out. On the `Launcher` page, click the drop down menu next to `Select a SageMaker Image` and select `PyTorch 1.6 Python 3.6 (Optimzed for GPU)`, then click the `Notebook` button below the dropdown.\n",
"\n",
"
\n",
"\n",
"Take a look at the upper righthand corner of the notebook. \n",
"\n",
"
\n",
"\n",
"The `Ptyhon 3 (PyTorch 1.6 Python 3.6 GPU Optimized)` refers to the kernel associated with this notebook. The `Unknown` refers to the current instance type. Click `Unknown` and select `ml.g4dn.xlarge`.\n",
"\n",
"
\n",
"\n",
"This will launch a `ml.g4dn.xlarge` instance and attach this notebook to it. This will take a couple of minutes, because Studio needs to download the PyTorch Docker image to the new instance. Once an instance has started, launching new notebooks with the same instance type and kernel is immediate. You'll also see the `Unknown` replaced with and instance description `4 vCPU + 16 GiB + 1 GPU`. You can also change instance as needed. Say you want to run your notebook on a `ml.p3dn.24xlarge` to get 8 GPUs. To change instances, just click the instance description. To get more instances in the menu, deselect `Fast launch only`.\n",
"\n",
"Once your notebook is up and running, you can also get a terminal into your new instance.\n",
"\n",
"
\n",
"\n",
"This can be useful for customizing your image with setup scripts, pip installing new packages, or using mpi to launch multi GPU training jobs. Click to get a terminal and run `ls`. Note that you have the same directories as your main Studio instance. Studio will attach the same EBS volume to all the instances you start, so all your files and data are shared across any notebooks you start. This means that you can prototype a model on a single GPU instance, then switch to a multi GPU instance while still having access to all of your data and scripts.\n",
"\n",
"Finally, when you want to shut down instances, click the circle with a square in it on the left hand side.\n",
"\n",
"
\n",
"\n",
"This shows your current running instances, and the Docker containers attached to those instances. To shut them down, just click the power button to their right.\n",
"\n",
"Now that we've explored studio a bit, let's get started with SageMakerCV. If you followed the instructions above to clone the repo, you should have `amazon-sagemaker-cv` in the file browser on the left. Navigate to `amazon-sagemaker-cv/pytorch/tutorial.ipynb` to open this notebook on your instance. If you still have a `g4dn` running, it should automatically attach to it.\n",
"\n",
"The rest of this notebook is broken into 4 sections.\n",
"\n",
"- Installing SageMakerCV and Downloading the COCO Data\n",
"\n",
"Since we're using the base AWS Deep Learning Container image, we need to add the SageMakerCV tools. Then we'll download the COCO dataset and upload it to S3.\n",
"\n",
"- Prototyping in Studio\n",
"\n",
"We'll walk through how to train a model on Studio, how SageMakerCV is structured, and how you can add your own models and features.\n",
"\n",
"- Launching a SageMaker Training Job\n",
"\n",
"There's lots of bells and whistles available to train your models fast, an on large datasets. We'll put a lot of those together to launch a high performance training job. Specifically, we'll create a training job with 4 P4d.24xlarge instances connected with 400 GB EFA, and streaming our training data from S3, so we don't have to load the dataset onto the instances before training. You could even use this same configuration to train on a dataset that wouldn't fit on the instances. If you'd rather only launch a smaller (or larger) training cluster, we'll discuss how to modify configuration.\n",
"\n",
"- Testing Our Model\n",
"\n",
"Finally, we'll take the output trained Mask RCNN model and visualize its performance in Studio.\n",
"\n",
"#### Installing SageMakerCV\n",
"\n",
"To install SageMakerCV on the PyTorch Studio Docker, just run `pip install -e .` in the `amazon-sagemaker-cv/pytorch` directory. You can do this with either an image terminal, or by running the paragraph below. Note that we use the `-e` option. This will keep the SageMakerCV modules editable, so any changes you make will be launched on your training job."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -e ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"### Setup on S3 and Download COCO data\n",
"\n",
"Next we need to setup an S3 bucket for all our data and results. Enter a name for your S3 bucket below. You can either create a new bucket, or use an existing bucket. If you use an existing bucket, make sure it's in the same region where you plan to run training. For new buckets, we'll specify that it needs to be in the current SageMaker region. By default we'll put everything in an S3 location on your bucket named `smcv-tutorial`, and locally in `/root/smcv-tutorial`, but you can change these locations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"S3_BUCKET = 'sagemaker-smcv-tutorial' # Don't include s3:// in your bucket name\n",
"S3_DIR = 'smcv-pytorch-tutorial'\n",
"LOCAL_DATA_DIR = '/root/smcv-pytorch-tutorial' #for reasons detailed in Destributed Training, do not put this dir in your source dir"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import zipfile\n",
"from pathlib import Path\n",
"from s3fs import S3FileSystem\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"import boto3\n",
"from botocore.client import ClientError\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s3 = boto3.resource('s3')\n",
"boto_session = boto3.session.Session()\n",
"region = boto_session.region_name\n",
"\n",
"# Check if bucket exists. If it doesn't, create it.\n",
"\n",
"try:\n",
" bucket = s3.meta.client.head_bucket(Bucket=S3_BUCKET)\n",
" print(f\"S3 Bucket {S3_BUCKET} Exists\")\n",
"except ClientError:\n",
" print(f\"Creating Bucket {S3_BUCKET}\")\n",
" bucket = s3.create_bucket(Bucket=S3_BUCKET, CreateBucketConfiguration={'LocationConstraint': region})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"Next we'll download the COCO data to Studio, unzip the files, and upload to S3. The reason we want the data in two places is that it's convenient to have the data locally on Studio for prototyping. We also want to unarchive the data before moving it to S3 so that we can stream it to our training instances instead of downloading it all at once.\n",
"\n",
"Once this is finished, you'll have copies of the COCO data on your Studio instance, and in S3. Be careful not to open the `data/coco/train2017` dir in the Studio file browser. It contains 118287 images, and can cause your web browser to crash. If you need to browse these files, use the terminal.\n",
"\n",
"This only needs to be done once, and only if you don't already have the data. The COCO 2017 dataset is about 20GB, so this step takes around 30 minutes to complete. The next paragraph sets up all the file directories we'll use for downloading, and later in training. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"COCO_URL=\"http://images.cocodataset.org\"\n",
"ANNOTATIONS_ZIP=\"annotations_trainval2017.zip\"\n",
"TRAIN_ZIP=\"train2017.zip\"\n",
"VAL_ZIP=\"val2017.zip\"\n",
"COCO_DIR=os.path.join(LOCAL_DATA_DIR, 'data', 'coco')\n",
"os.makedirs(COCO_DIR, exist_ok=True)\n",
"S3_DATA_LOCATION=os.path.join(\"s3://\", S3_BUCKET, S3_DIR, \"data\", \"coco\")\n",
"S3_WEIGHTS_LOCATION=os.path.join(\"s3://\", S3_BUCKET, S3_DIR, \"data\", \"weights\")\n",
"WEIGHTS_DIR=os.path.join(LOCAL_DATA_DIR, 'data', 'weights')\n",
"os.makedirs(WEIGHTS_DIR, exist_ok=True)\n",
"R50_WEIGHTS=\"resnet50.pkl\"\n",
"R50_WEIGHTS_SRC=\"https://sagemakercv.s3.us-west-2.amazonaws.com/weights/pytorch\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"And this paragraph will download everything, and take around 30 minutes to complete."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Downloading annotations\")\n",
"!wget -O $COCO_DIR/$ANNOTATIONS_ZIP $COCO_URL/annotations/$ANNOTATIONS_ZIP\n",
"!unzip $COCO_DIR/$ANNOTATIONS_ZIP -d $COCO_DIR\n",
"!rm $COCO_DIR/$ANNOTATIONS_ZIP\n",
"!aws s3 cp --recursive $COCO_DIR/annotations $S3_DATA_LOCATION/annotations\n",
"\n",
"print(\"Downloading COCO training data\")\n",
"!wget -O $COCO_DIR/$TRAIN_ZIP $COCO_URL/zips/$TRAIN_ZIP\n",
"\n",
"# train data has ~128000 images. Unzip is too slow, about 1.5 hours beceause of disk read and write speed on the EBS volume. \n",
"# This technique is much faster because it grabs all the zip metadata at once, then uses threading to unzip multiple files at once.\n",
"print(\"Unzipping COCO training data\")\n",
"train_zip = zipfile.ZipFile(os.path.join(COCO_DIR, TRAIN_ZIP))\n",
"jpeg_files = [image.filename for image in train_zip.filelist if image.filename.endswith('.jpg')]\n",
"os.makedirs(os.path.join(COCO_DIR, 'train2017'))\n",
"with ThreadPoolExecutor() as executor:\n",
" threads = list(tqdm(executor.map(lambda x: train_zip.extract(x, COCO_DIR), jpeg_files), total=len(jpeg_files)))\n",
"\n",
"# same issue for uploading to S3. this uploads in parallel, and is faster than using aws cli in this case.\n",
"print(\"Uploading COCO training data to S3\")\n",
"train_images = [i for i in Path(os.path.join(COCO_DIR, 'train2017')).glob(\"*.jpg\")]\n",
"s3fs = S3FileSystem()\n",
"with ThreadPoolExecutor() as executor:\n",
" threads = list(tqdm(executor.map(lambda image: s3fs.put(image.as_posix(), os.path.join(S3_DATA_LOCATION, 'train2017', image.name)), \n",
" train_images), total=len(train_images)))\n",
"# !rm $COCO_DIR/$TRAIN_ZIP\n",
"\n",
"print(\"Downloading COCO validation data\")\n",
"!wget -O $COCO_DIR/$VAL_ZIP $COCO_URL/zips/$VAL_ZIP\n",
"# switch to also threading\n",
"!unzip -q $COCO_DIR/$VAL_ZIP -d $COCO_DIR\n",
"val_images = [i for i in Path(os.path.join(COCO_DIR, 'val2017')).glob(\"*.jpg\")]\n",
"with ThreadPoolExecutor() as executor:\n",
" threads = list(tqdm(executor.map(lambda image: s3fs.put(image.as_posix(), os.path.join(S3_DATA_LOCATION, 'val2017', image.name)), \n",
" val_images), total=len(val_images)))\n",
"# !rm $COCO_DIR/$VAL_ZIP\n",
"\n",
"# grab resnet backbone from public S3 bucket\n",
"!wget -O $WEIGHTS_DIR/$R50_WEIGHTS $R50_WEIGHTS_SRC/$R50_WEIGHTS\n",
"!aws s3 cp $WEIGHTS_DIR/$R50_WEIGHTS $S3_WEIGHTS_LOCATION/$R50_WEIGHTS\n",
"\n",
"print(\"FINISHED!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"### Training on Studio\n",
"\n",
"Now that we have the data, we can get to training a Mask RCNN model to detect objects in the COCO dataset images. \n",
"\n",
"Since training on a single GPU can take days, we'll just train for a couple thousands steps, and run a single evaluation to make sure our model is at least starting to learn something. We'll train a full model on a larger cluster of GPUs in a SageMaker training job.\n",
"\n",
"The reason we first want to train in Studio is that we want to dig a bit into the SageMakerCV framework, and talk about the model architecture, since we expect many users will want to modify models for their own use cases.\n",
"\n",
"#### Mask RCNN\n",
"\n",
"First, just a very brief overview of Mask RCNN. If you would like a more in depth examination, we recommend taking a look at the [original paper](https://arxiv.org/abs/1703.06870), the [feature pyramid paper](https://arxiv.org/abs/1612.03144) which describes a popular architectural change we'll use in our model, and blog posts from [viso.ai](https://viso.ai/deep-learning/mask-r-cnn/), [tryo labs](https://tryolabs.com/blog/2018/01/18/faster-r-cnn-down-the-rabbit-hole-of-modern-object-detection/), [Jonathan Hui](https://jonathan-hui.medium.com/image-segmentation-with-mask-r-cnn-ebe6d793272), and [Lilian Weng](https://lilianweng.github.io/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html).\n",
"\n",
"Mask RCNN is a two stage object detection model that locates objects in images by places bounding boxes around, and segmentation masks over, any object for which the model is trained to find. It also provides classifcations for each object.\n",
"\n",
"
\n",
"\n",
"Mask RCNN is called a two stage model because it performs detection in two steps. The first identified any objects in the image, versus background. The second stage determines the specific class of each object, and applies the segmentation mask. Below is an architectural diagram of the model. Let's walk through each step.\n",
"\n",
"
\n",
"Credit: Jonathan Hui\n",
"\n",
"The `Convolution Network` is often referred to as the model backbone. This is a pretrained image classification model, commonly ResNet, which has been trained on a large image classification dataset, like ImageNet. The classification layer is removed, and instead the backbone outputs a set of convolution feature maps. The idea is, the classification model learned to identify objects in the process of classifying images, and now we can use that information to build a more complex model that can find those objects in the image. We want to pretrain because training the backbone at the same time as training the object detector tends to be very unstable.\n",
"\n",
"One additional component that is sometimes added to the backbone is a `Fearure Pyramid Network`. This take the outputs of the backbone, and combines them to together into a new set of feature maps by perform both up and down convolutions. The idea is that the different sized feature maps will help the model detect images of different sizes. The feature pyramid also helps with this, by allowing the different feature maps to share information with each other.\n",
"\n",
"The outputs of the feature pyramid are then passed to the `Region Proposal Network` which is responsible for finding regions of the image that might contain an object (this is the first of the two stages). The RPN will output several hundred thousand regions, each with a probability of containing an object. We'll typically take the top few thousand most likely regions. Because these several thousand regions will usually have a lot of overlap, we perform [non-max supression](https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c), which removed regions with large areas of overlap. This gives us a set of `regions of interest` regions of the image that we think might contain an image.\n",
"\n",
"Next, we use those regions to crop out the corresponding sections of the feature maps that came from the feature pyramid network using a technique called [ROI align](https://firiuza.medium.com/roi-pooling-vs-roi-align-65293ab741db).\n",
"\n",
"We pass our cropped feature maps to the `box head` which classifies each region into either a specific object category, or as background. It also refines the position of the bounding box. In Mask RCNN, we also pass the feature maps to a `mask head` which produces a segmentation mask over the object.\n",
"\n",
"#### SageMakerCV Internals\n",
"\n",
"An important feature of Mask RCNN is its multiple heads. One head constructs a bounding box, while another creates a mask. These are referred to as the `ROI heads`. It's common for users to extend this and other two stage models by adding their own ROI heads. For example, a keypoint head it common. Doing so means modifying SageMakerCV's internals, so let's talk about those for a second. \n",
"\n",
"The high level Mask RCNN model can be found in `amazon-sageamaker-cv/pytorch/sagemakercv/detection/detector/generatlized_rcnn.py`. If you trace through the forward function, you'll see that the model first passes an image through the backbone (which also contains the feature pyramid), then the RPN in the graphable module. Then results are then passed through non-max suppression, and into the roi heads. \n",
"\n",
"Probably the most important feature to be aware of are the `build` imports at the top. Each section of the model has an associated build function `(build_backbone, build_rpn, build_roi_heads)`. These functions simplify building the model by letting us pass in a single configuration file for building all the different pieces. \n",
"\n",
"For example, if you open `amazon-sageamaker-cv/pytorch/sagemakercv/detection/roi_heads/roi_heads.py`, you'll find the `build_roi_heads` function at the bottom. To add a new head, you would write a torch module with its own build function, and call the build function from here.\n",
"\n",
"For example, say you want to add a keypoint head to the model. An example keypoint module and associate build function is in `amazon-sageamaker-cv/pytorch/sagemakercv/detection/roi_heads/keypoint_head/keypoint_head.py`. To enable the keypoint head, you would set `cfg.MODEL.KEYPOINY_ON=True` and add the keypoint parameters to your configuration yaml file.\n",
"\n",
"SageMakerCV uses similar build functions for the optimizers and schedulers, which you can add to or modify in the `amazon-sageamaker-cv/pytorch/sagemakercv/training/optimizers/` directory. \n",
"\n",
"Finally, data loading tools are located in `amazon-sagemaker-cv/pytorch/data/`. Here you can add a new dataset, sampler, and preprocessing data transformations. Data loaders are constructed in the `build.py` file. Notice the `@DATASETS.register(\"COCO\")` decorator at the top of the COCO `make_coco_dataloader` function. This adds the function to a dictionary of datasets, so that when you specify `COCO` in yout configuration file, the `make_data_loader` knows which data loader to create.\n",
"\n",
"#### Setting Up Training\n",
"\n",
"Let's actually use some of these functions to train a model.\n",
"\n",
"Start by importing the default configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from configs import cfg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"We use the [yacs](https://github.com/rbgirshick/yacs) format for configuration files. If you want to see the entire config, run `print(cfg.dump())` but this prints out a lot, and to not overwhelm you with too much information, we'll just focus on the bits we want to change for this model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"First, let's put in all the file directories for the data and weights we downloaded in the previous section, as well as an output directory for the model results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.INPUT.TRAIN_INPUT_DIR = os.path.join(COCO_DIR, \"train2017\")\n",
"cfg.INPUT.VAL_INPUT_DIR = os.path.join(COCO_DIR, \"val2017\")\n",
"cfg.INPUT.TRAIN_ANNO_DIR = os.path.join(COCO_DIR, \"annotations\", \"instances_train2017.json\")\n",
"cfg.INPUT.VAL_ANNO_DIR = os.path.join(COCO_DIR, \"annotations\", \"instances_val2017.json\")\n",
"cfg.OUTPUT_DIR = os.path.join(LOCAL_DATA_DIR, \"model-output\")\n",
"# create output dir if it doesn't exist\n",
"os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
"# backbone weights file\n",
"cfg.MODEL.WEIGHT=os.path.join(WEIGHTS_DIR, R50_WEIGHTS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"****\n",
"\n",
"Next we need to setup our data loader. The data loader is a tool in PyTorch that handles loading and applying transformations on our dataset. For more information see the [PyTorch Dataset and Data Loader Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).\n",
"\n",
"For our purposes, we need to set two parameters. The size divisibility has to do with downsampling that occurs in the backbone and feature pyramid. The feature maps generated by these layers must evenly divide into the original image size. For example, if the input image is 1344x1344, the smallest feature map will be 42x42, which means the input image must be divisible by 32 `(42x32=1344)`. The `SIZE_DIVISIBLITY` parameter makes sure all images are resized to be multiples of 32. \n",
"\n",
"The number of workers has to do with the number of background processes that will run the data loader in parallel. This can be useful for really high performance systems, to make sure our data loader is feeding data to the GPUs fast enough. However, for prototyping, these background processes can cause memory problems, so we'll turn it off when running in a notebook. We'll switch them back on later for our larger training job."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# dataloader settings\n",
"cfg.DATALOADER.SIZE_DIVISIBILITY=32\n",
"cfg.DATALOADER.NUM_WORKERS=0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"This section specifies model details, including the type of model, and internal hyperparameters. We wont cover the details of all of these, but more information can be found in this blog posts listed above, as well as the original paper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.MODEL.META_ARCHITECTURE=\"GeneralizedRCNN\" # The type of model we're training. found in amazon-sagemaker-cv/pytorch/sagemakercv/detection/detector/generalized_rcnn.py\n",
"cfg.MODEL.RESNETS.TRANS_FUNC=\"BottleneckWithFixedBatchNorm\" # Type of bottleneck function in the Resnet50 backbone. see https://arxiv.org/abs/1512.03385\n",
"cfg.MODEL.BACKBONE.CONV_BODY=\"R-50-FPN\" # Type of backbone, Resnet50 with feature pyramid network\n",
"cfg.MODEL.BACKBONE.OUT_CHANNELS=256 # number of channels on the output feature maps from the backbone\n",
"cfg.MODEL.RPN.USE_FPN=True # Use Feature Pyramid. RPN needs to know this since FPN adds an extra feature map\n",
"cfg.MODEL.RPN.ANCHOR_STRIDE=(4, 8, 16, 32, 64) # positions of anchors, see blog posts for details\n",
"cfg.MODEL.RPN.PRE_NMS_TOP_N_TRAIN=2000 # top N anchors to keep before non-max suppression during training\n",
"cfg.MODEL.RPN.PRE_NMS_TOP_N_TEST=1000 # top N anchors to keep before non-max suppression during testing\n",
"cfg.MODEL.RPN.POST_NMS_TOP_N_TEST=1000 # top N anchors to keep after non-max suppression during testing\n",
"cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN=1000 # top N anchors to keep after non-max suppression during training\n",
"cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST=1000 # top N anchors to keep before non-max suppression during training\n",
"cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_PER_IMAGE=True # Run NMS per FPN level\n",
"cfg.MODEL.RPN.LS=0.1 # label smoothing improves performance on less common categories\n",
"\n",
"# ROI Heads\n",
"cfg.MODEL.ROI_HEADS.USE_FPN=True # Use Feature Pyramid. ROI needs to know this since FPN adds an extra feature map\n",
"cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS=(10., 10., 5., 5.) # Regression wieghts for bounding boxes, see blog posts\n",
"cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION=7 # Pixel size of region cropped from feature map\n",
"cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES=(0.25, 0.125, 0.0625, 0.03125) # Pooling for ROI align\n",
"cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO=2 # Sampling for ROI Align\n",
"cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR=\"FPN2MLPFeatureExtractor\" # Type of ROI feature extractor found in SageMakerCV core utils\n",
"cfg.MODEL.ROI_BOX_HEAD.PREDICTOR=\"FPNPredictor\" # Predictor type used for inference found in SageMakerCV core utils\n",
"cfg.MODEL.ROI_BOX_HEAD.LOSS=\"GIoULoss\" # Use GIoU loss, improves box performance https://giou.stanford.edu/GIoU.pdf\n",
"cfg.MODEL.ROI_BOX_HEAD.DECODE=True # Convert boxes to pixel positions\n",
"cfg.MODEL.ROI_BOX_HEAD.CARL=True # Use carl loss https://arxiv.org/pdf/1904.04821.pdf\n",
"cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES=(0.25, 0.125, 0.0625, 0.03125) # Mask head ROI align\n",
"cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR=\"MaskRCNNFPNFeatureExtractor\" # Mask feature extractor type in SageMakerCV core utils\n",
"cfg.MODEL.ROI_MASK_HEAD.PREDICTOR=\"MaskRCNNC4Predictor\" # Predictor used for inference\n",
"cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION=14 # Pixel size of region cropped from feature map\n",
"cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO=2 # ROI align sampling ratio\n",
"cfg.MODEL.ROI_MASK_HEAD.RESOLUTION=28 # output resolution of mask\n",
"cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR=False # share feature extractor between box and mask heads\n",
"cfg.MODEL.MASK_ON=True # use mask head"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Next we set up the configuration for training, including the optimizer, hyperparameters, batch size, and training length. Batch size is global, so if you set a batch size of 64 across 8 GPUs, it will be a batch size of 8 per GPU. SageMakerCV currently supports the following optimizere: SGD (stochastic gradient descent), Adam, Lamb, and NovoGrad [link - this speeds up training by allowing increased batch sizes], and the following learning rate schedulers: stepwise and cosine decay. New, custom optimizers and schedulers can be added by modifying the `sagemakercv/training/build.py` file.\n",
"\n",
"For training on Studio, we'll just run for a few hundred steps. We'll be using SageMaker training instances for the full training on multiple GPUs.\n",
"\n",
"We also set the mixed precision optimization level. This is a value between O0-O4. O0 means no optimization, and training purely in FP32. O1-O3 are varying degrees of mixed precision, explained in the [Nvidia Apex documentation](https://nvidia.github.io/apex/amp.html). O4 is a special optimization level that combines pure FP16 training with channel last memory optimizations. This is the optimization we used to achieve the performance we announced at Re:Invent in 2020. However, optimization levels O2-O4 introduce some numerical instability, and can take a lot of tuning to train properly. For most users, we recommend starting on O1, since this provides most performance benefits of mixed precision, while retaining good numeric stability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.SOLVER.OPTIMIZER=\"NovoGrad\" # Type of optimizer, NovoGrad, Adam, SGD, Lamb\n",
"cfg.SOLVER.BASE_LR=0.004 # Learning rate after warmup [Suggested values]\n",
"cfg.SOLVER.BETA1=0.9 # Beta value for Novograd, Adam, and Lamb\n",
"cfg.SOLVER.BETA2=0.4 # Beta value for Novograd, Adam, and Lamb\n",
"cfg.SOLVER.ALPHA=.1 # Alpha for final value of cosine decay\n",
"cfg.SOLVER.LR_SCHEDULE=\"COSINE\" # Decay type, COSINE or MULTISTEP\n",
"cfg.SOLVER.IMS_PER_BATCH=8 # Global training batch size, must be a multiple of the number of GPUs\n",
"cfg.SOLVER.WEIGHT_DECAY=0.0005 # Training weight decay applied as decoupled weight decay on optimizer. [Paper link]\n",
"cfg.SOLVER.MAX_ITER=2500 # Total number of training steps for local training\n",
"cfg.SOLVER.WARMUP_FACTOR=.01 # Starting learning rate as a multiple of the BASE_LR\n",
"cfg.SOLVER.WARMUP_ITERS=100 # Number of warmup steps to reach BASE_LR\n",
"cfg.SOLVER.GRADIENT_CLIPPING=0.0 # Gradient clipping norm, leave as 0.0 to disable gradient clipping\n",
"cfg.OPT_LEVEL=\"O1\" # Mixed precision optimization level\n",
"cfg.TEST.IMS_PER_BATCH=16 # Evaluation batch size, must be a multiple of the number of GPUs\n",
"cfg.TEST.PER_EPOCH_EVAL=True # Eval after every epoch or only at the end of training for local training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Finally, SageMakerCV includes a number of training hooks, ie. tools that will trigger on certain events during training. \n",
"\n",
"For example, the `DetectronCheckpointHook` tells the trainer to read and write model checkpoints in the [Detectron](https://github.com/facebookresearch/detectron2) format. This will read in a checkpoint at the beginning of training, if one is provided, and write a checkpoint after each epoch. In this case, our backbone weights will be the starting checkpoint.\n",
"\n",
"The `AMP_Hook` applies automatic mixed precision to the model, based on the optimization level we set.\n",
"\n",
"The `IterTimerHook` and `TextLoggerHook` record information about training step time and loss values at each iteration, and format the results to be easy to read in AWS CloudWatch.\n",
"\n",
"The `COCOEvaluation` run evaluation either after each epoch, or at the end of training.\n",
"\n",
"We can supply these hooks to the configuration as a list of strings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.HOOKS=[\"DetectronCheckpointHook\",\n",
" \"AMP_Hook\",\n",
" \"IterTimerHook\",\n",
" \"TextLoggerHook\",\n",
" \"COCOEvaluation\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Let's save this configuration file as a yaml, so we have a record of how we created this training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from contextlib import redirect_stdout\n",
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"local_config_file = f\"configs/local-config-studio.yaml\"\n",
"with open(local_config_file, 'w') as outfile:\n",
" with redirect_stdout(outfile): print(cfg.dump())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Here's how you can load this saved yaml file and map it to the configuration for future training sessions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.merge_from_file(local_config_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Now let's actually run some training. Since we expect users will want to modify the training for their own cases, we'll use some of the training tools more directly than normal.\n",
"\n",
"explain why we want to do this interactively all in train.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from sagemakercv.detection.detector import build_detection_model # takes cfg builds model\n",
"from sagemakercv.training import make_optimizer, make_lr_scheduler # takes cfg builds opt sch\n",
"from sagemakercv.data import make_data_loader, Prefetcher # takes cfg builds data loader\n",
"from sagemakercv.utils.runner import build_hooks, Runner # model trainer\n",
"from sagemakercv.training.trainers import train_step # actual train step\n",
"from sagemakercv.utils.runner.hooks.checkpoint import DetectronCheckpointHook # need to run this hook first, explain why"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Create our dataloader\n",
"\n",
"`num_iterations` is the number of expected steps per epoch. It's the size of the dataset divided by the global batch size. It's used so that we can keep track of when we've reached the end of an epoch. For this small local training, we'll actually be training for fewer steps than a full epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_coco_loader, num_iterations = make_data_loader(cfg, is_distributed=False) #local traiining explain num_iterations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"The prefetcher boosts performance by asynchronously grabbing the next training element, and sending it to the GPU before it's needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(cfg.MODEL.DEVICE) # tell torch to use the GPU not the CPU as in default config\n",
"train_iterator = Prefetcher(iter(train_coco_loader), device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Build the model, optimizer, and scheduler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = build_detection_model(cfg)\n",
"model.to(device) #send to GPU\n",
"optimizer = make_optimizer(cfg, model)\n",
"scheduler = make_lr_scheduler(cfg, optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Build the hooks we set in the config."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hooks = build_hooks(cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Now we're to build a runner, which is a training tool designed for specifically for SageMaker. It manages your training steps, hooks, and logs, making it easier to track models on CloudWatch as well as managing sending data back and forth from S3. To build the runner, we need to define a `train_step`, which is just the standard pytorch training step. This is similar to Keras in TensorFlow but more customizable. For example, a very basic train_step would be something like:\n",
"\n",
"```\n",
"def train_step(inputs, model, optimizer, scheduler):\n",
" optimizer.zero_grad()\n",
" loss = model(inputs)\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" return loss\n",
"```\n",
"\n",
"In this case, the `train_step` can be found in `sagemakercv/training/trainers.py` and includes all steps needed for Mask RCNN."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"runner = Runner(model, train_step, cfg, device, optimizer, scheduler)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Register hooks with the runner so it knows when to trigger them. [explain detectron]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for hook in hooks:\n",
" runner.register_hook(hook, priority='HIGHEST' if isinstance(hook, DetectronCheckpointHook) else 'NORMAL')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"And finally we're ready to run training. This will print a lot of info as it's setting up the model. When training starts, you'll sometimes see a few `Gradient Overflow` warnings. This is fine, it's just the mixed precision adjusting the loss scaling. After about a minute you should start seeing loss values in step increments of 50. At 2500 steps on a G4dn instane, this takes about an hour. You can speed this up by reducing the number of training steps, or using a P3.2xlarge instance. You should get a [MaP score](https://towardsdatascience.com/map-mean-average-precision-might-confuse-you-5956f1bfa9e2) of around .1 for both box and mask after 2500 steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"runner.run(train_iterator, num_iterations)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"Okay, so we printed a lot of numbers. Before moving on, let's first visualize what our model learned so we can make sure it's working how we expect.\n",
"\n",
"SageMakerCV includes some simple visualization tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"import gc\n",
"from sagemakercv.data.datasets.evaluation.coco.coco_labels import coco_categories\n",
"from sagemakercv.utils.visualize import Visualizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"viz = Visualizer(model, cfg, temp_dir=cfg.OUTPUT_DIR, categories=coco_categories)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The visualizer can take a local file path, S3 location, or web address for an image. It will run the image through the model you just trained, and output it's predictions for objects it finds.\n",
"\n",
"You can use any images you want, either from the Coco data we downloaded earlier, or your own. For example, [pixalbay](https://pixabay.com/) has lots of free use images. Just copy the image address into the `image_src` below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_src = 'https://cdn.pixabay.com/photo/2021/07/29/14/48/new-york-6507350_1280.jpg'\n",
"#'https://cdn.pixabay.com/photo/2020/05/12/11/39/cat-5162540__480.jpg'\n",
"viz(image_src, threshold=0.9) # Threshold is the minimum probability to display. Usually set around .75 or .9 so we don't get a bunch of spots where the model says there's a 2% chance of hot dogs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"Note that running interactively in a notebook can build of GPU memory pressure over time. If you run inference a lot of times, you might eventually get a cuda out of memory error. If this happens, just run the garbage collection in the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"Before moving on, let's delete our model and related training tools so we don't keep too much in memory. The model you just trained is saved in your `cfg.OUTPUT_DIR` directory, so you can reload it later. We'll cover how to import a saved model in the final section."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"del model, optimizer, scheduler, runner\n",
"\n",
"torch.cuda.empty_cache()\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"\n",
"### Distributed Training\n",
"\n",
"Great! You've managed to train a model (albeit for just a few thousand steps) on your studio instance. For many users this will be all they need. If you only intend to deal with small amounts of data (<10G) that can be trained on a single GPU, you can do everything you need with what we've run so far.\n",
"\n",
"However, most practical applications of Mask RCNN require training on huge datasets across many GPUs or nodes. For that, we'll need to launch a SageMaker training job. Here we can train a model on as many as 512 [A100 GPUs](https://www.nvidia.com/en-us/data-center/a100/). We won't go quite that far. Instead, let's try training on 32 A100 GPUs across 4 P4d nodes. We'll also cover how to modify the configuration for different GPU types and counts.\n",
"\n",
"The section below is also replicated in the `SageMaker.ipynb` notebook for future training once all the above setup is complete.\n",
"\n",
"Before we get started, a few notes about how SageMaker training instances work. SageMaker takes care of a lot of setup for you, but it's important to understand a little of what's happening under the hood so you can customize training to your own needs. \n",
"\n",
"First we're going to look at a toy estimator to explain what's happening:\n",
"\n",
"```\n",
"from sagemaker import get_execution_role\n",
"from sagemaker.pytorch import PyTorch\n",
"\n",
"estimator = PyTorch(\n",
" entry_point='train.py', \n",
" source_dir='.', \n",
" py_version='py3',\n",
" framework_version='1.8.1',\n",
" role=get_execution_role(),\n",
" instance_count=4,\n",
" instance_type='ml.p4d.24xlarge',\n",
" distribution=distribution,\n",
" output_path='s3://my-bucket/my-output/',\n",
" checkpoint_s3_uri='s3://my-bucket/my-checkpoints/',\n",
" model_dir='s3://my-bucket/my-model/',\n",
" hyperparameters={'config': 'my-config.yaml'},\n",
" volume_size=500,\n",
" code_location='s3://my-bucket/my-code/',\n",
")\n",
"```\n",
"\n",
"The estimator forms the basic configuration of your training job.\n",
"\n",
"SageMaker will first launch `instance_count=4` `instance_type=ml.p4d.24xlarge` instances. The `role` is an IAM role that SageMaker will use to launch instances on your behalf. SageMaker includes a `get_execution_role` function which grabs the execution role of your current instance. Each instance will have a `volume_size=500` EBS volume attached for your model and data. On `ml.p4d.24xlarge` and `ml.p3dn.24xlarge` instance types, SageMaker will automatically set up the [Elastic Fabric Adapter](https://aws.amazon.com/hpc/efa/). EFA provides up to 400 GB/s communication between your training nodes, as well as [GPU Direct RDMA](https://aws.amazon.com/about-aws/whats-new/2020/11/efa-supports-nvidia-gpudirect-rdma/) on `ml.p4d.24xlarge`, which allows your GPUs to bypass the host and communicate directly with each other across nodes.\n",
"\n",
"Next, SageMaker we copy all the contents of `source_dir='.'` first to the `code_location='s3://my-bucket/my-code/'` S3 location, then to each of your instances. One common mistake is to leave large files or data in this directory or its subdirectories. This will slow down your launch times, or can even cause the launch to hang. Make sure to keep your working data and model artifacts elsewhere on your Studio instance so you don't accidently copy them to your training instance. You should instead use `Channels` to copy data and model artifacts, which we'll cover shortly.\n",
"\n",
"SageMaker will then download the training Docker image to all your instances. Which container you download is determined by `py_version='py3'` and `framework_version='1.8.1'`. You can also use your own [custom Docker image](https://aws.amazon.com/blogs/machine-learning/bringing-your-own-custom-container-image-to-amazon-sagemaker-studio-notebooks/) by specifying an ECR address with the `image_uri` option. When building a custom container, we recommend building the [CUDA utilities](https://github.com/aws-samples/amazon-sagemaker-cv/tree/main/pytorch/cuda_utils) from source within the container, to ensure compatibility with your PyTorch version. SageMakerCV currently include prebuilt whls for CUDA utilities for versions 1.6-1.9.\n",
"\n",
"Before starting training, SageMaker will check your source directory for a `setup.py` file, and install if one is present. Then SageMaker will actually launch training, via `entry_point='train.py'`. Anything in `hyperparameters={'config': 'my-config.yaml'}` will be passed to the training script as a command line argument (ie `python train.py --config my-config.yaml`). The distribution will determine what form of distributed training to launch. This will be covered in more detail later.\n",
"\n",
"During training, anything written to `/opt/ml/checkpoints` on your training instances will be synced to `checkpoint_s3_uri='s3://my-bucket/my-checkpoints/'` at the same time. This can be useful for checkpointing a model you might want to restart later, or for writting Tensorboard logs to monitor your training.\n",
"\n",
"When training complets, you can write your model artifacats to `/opt/ml/model` and it will save to `model_dir='s3://my-bucket/my-model/'`. Another option is to also write model artifacts to your checkpoints file.\n",
"\n",
"Training logs, and any failure messages will to written to `/opt/ml/output` and saved to `output_path='s3://my-bucket/my-output/'`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import get_execution_role\n",
"from sagemaker.pytorch import PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we need to set some names. You want `AWS_DEFAULT_REGION` to be the same region as the S3 bucket your created earlier, to ensure your training jobs are reading from nearby S3 buckets.\n",
"\n",
"Next, set a `user_id`. This is just for naming your training job so it's easier to find later. This can be anything you like. We also get the current date and time to make organizing training jobs a little easier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# explain region. Don't launch a training job in VA with S3 bucket in OR\n",
"os.environ['AWS_DEFAULT_REGION'] = region # This is the region we set at the beginning, when creating the S3 bucket for our data\n",
"\n",
"# this is all for naming\n",
"user_id=\"username-smcv-tutorial\" # This is used for naming your training job, and organizing your results on S3. It can be anything you like.\n",
"date_str=datetime.now().strftime(\"%d-%m-%Y\")\n",
"time_str=datetime.now().strftime(\"%d-%m-%Y-%H-%M-%S\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance type, we'll use an `ml.p4d.24xlarge`. We recommend this instance type for large training. It includes the latest A100 Nvidia GPUs, which can train several times faster than the previous generation. If you would rather train part way on smaller instanes, `ml.p3.2xlarge, ml.p3.8xlarge, ml.p3.16xlarge, ml.p3dn.24xlarge, ml.g4dn.12xlarge` are all good options. In particular, if you're looking for a low cost way to try a short distributed training, but aren't worried about the model fully converging, we recommend the `ml.g4dn.12xlarge` which uses 4 Nvidia T4 GPUs per node.\n",
"\n",
"`s3_location` will be the base S3 storage location we used earlier for the COCO data. For `role` we get the execution role from our studio instance. For `source_dir` we use the current directory. Again, make sure you haven't accidently written any large files to this directory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# specify training type, s3 src and nodes\n",
"instance_type=\"ml.p4d.24xlarge\" # This can be any of 'ml.p3dn.24xlarge', 'ml.p4d.24xlarge', 'ml.p3.16xlarge', 'ml.p3.8xlarge', 'ml.p3.2xlarge', 'ml.g4dn.12xlarge'\n",
"nodes=1 # 4\n",
"s3_location=os.path.join(\"s3://\", S3_BUCKET, S3_DIR)\n",
"role=get_execution_role() #give Sagemaker permission to launch nodes on our behalf\n",
"source_dir='.'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Let's modify our previous training configuration for multinode. We don't need to change much. We'll increase the batch size since we have more and large GPUs. For A100 GPUs a batch size of 12 per GPU works well. For V100 and T4 GPUs, a batch size of 6 per GPU is recommended. Make sure to lower the learning rate and increase your number of training steps if you decrease the batch size. For example, if you want to train on 2 `ml.g4dn.12xlarge` instances, you'll have 8 T4 GPUs. A batch size of `cfg.SOLVER.IMS_PER_BATCH=48`, with inference batch size of `cfg.TEST.IMS_PER_BATCH=32`, learning rate of `cfg.SOLVER.BASE_LR=0.006`, and training steps of `cfg.SOLVER.MAX_ITER=25000` is probably about right. \n",
"\n",
"The configuration below has been tested to converge to better than MLPerf accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.SOLVER.OPTIMIZER=\"NovoGrad\" # Type of optimizer, NovoGrad, Adam, SGD, Lamb\n",
"cfg.SOLVER.BASE_LR=0.016 # 0.042 # Learning rate after warmup\n",
"cfg.SOLVER.BETA1=0.9 # Beta value for Novograd, Adam, and Lamb\n",
"cfg.SOLVER.BETA2=0.3 # Beta value for Novograd, Adam, and Lamb\n",
"cfg.SOLVER.ALPHA=.001 # Alpha for final value of cosine decay\n",
"cfg.SOLVER.LR_SCHEDULE=\"COSINE\" # Decay type, COSINE or MULTISTEP\n",
"cfg.SOLVER.IMS_PER_BATCH=96 #384 # Global training batch size, must be a multiple of the number of GPUs\n",
"cfg.SOLVER.WEIGHT_DECAY=0.001 # Training weight decay applied as decoupled weight decay on optimizer\n",
"cfg.SOLVER.MAX_ITER=15000 #5000 # Total number of training steps\n",
"cfg.SOLVER.WARMUP_FACTOR=.01 # Starting learning rate as a multiple of the BASE_LR\n",
"cfg.SOLVER.WARMUP_ITERS=625 # Number of warmup steps to reach BASE_LR\n",
"cfg.SOLVER.GRADIENT_CLIPPING=0.0 # Gradient clipping norm, leave as 0.0 to disable gradient clipping\n",
"cfg.OPT_LEVEL=\"O1\" # Mixed precision optimization level\n",
"cfg.TEST.IMS_PER_BATCH=64 # 128 # Evaluation batch size, must be a multiple of the number of GPUs\n",
"cfg.TEST.PER_EPOCH_EVAL=True # False # Eval after every epoch or only at the end of training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Earlier we mentioned the `distrbution` strategy in SageMaker. Distributed training can be either multi GPU single node (ie training on 8 GPU in a single ml.p4d.24xlarge) or mutli GPU multi node (ie training on 32 GPUs across 4 ml.p4d.24xlarges). For PyTorch you can use either [PyTorch's built in DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or [SageMaker Distributed Data Parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html) (SMDDP). For single node multi GPU we recommend PyTorch DDP, while for multinode we recommend SMDDP. SMDDP is built to fully utilize AWS network topology, and EFA, providing a speed boost on multinode.\n",
"\n",
"To enable SMDDP, set `distribution = { \"smdistributed\": { \"dataparallel\": { \"enabled\": True } } }`. SageMakerCV already has SMDDP integrated. To implement SMDDP for your own models, follow [these instructions](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-intro.html). SMDDP will launch training from the first node in your cluster using [MPI](https://www.open-mpi.org/).\n",
"\n",
"For PyTorch DDP we'll actually disable SageMaker's distribution tool, and set it up manually. When you disable distribution, SageMaker will simply launch the same `entry_point` script on each node. The setup below is to run PyTorch DDP for multi GPU on a single node, but you can use the same setup to run PyTorch DDP multinode as well. Manually setting up your own distribution can be useful when you want to do complex custom distribution strategies. For example, you can use the `SM_CURRENT_HOST` environment variable on each node to set node specific parameters. PyTorch DDP requires running the `torch.distributed.launch` on each training node. The `launch_ddp.py` script grabs the environment variables, then launches `train.py` with `torch.distributed.launch`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if nodes>1 and instance_type in ['ml.p3dn.24xlarge', 'ml.p4d.24xlarge', 'ml.p3.16xlarge']:\n",
" distribution = { \"smdistributed\": { \"dataparallel\": { \"enabled\": True } } } \n",
" entry_point = \"train.py\"\n",
"else:\n",
" distribution = None\n",
" entry_point = \"launch_ddp.py\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"We'll set a job name based on the user name and time. We'll then set output directories on S3 using the date and job name.\n",
"\n",
"For this training, we'll use the same S3 location for all 3 SageMaker model outputs `/opt/ml/checkpoint`, `/opt/ml/model`, and `/opt/ml/output`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"job_name = f'{user_id}-{time_str}'\n",
"output_path = os.path.join(s3_location, \"sagemaker-output\", date_str, job_name)\n",
"code_location = os.path.join(s3_location, \"sagemaker-code\", date_str, job_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Next we need to add our data sources to our configuration file, but first let's talk a little more about how SageMaker gets data to your instance.\n",
"\n",
"The most straightforward way to get your data is using \"Channels.\" These are S3 locations you specify in a dictionary when you launch a training job. For example, let's say you launch a training job with:\n",
"\n",
"```\n",
"channels = {'train': 's3://my-bucket/data/train/',\n",
" 'test': 's3://my-bucket/data/test/',\n",
" 'weights': 's3://my-bucket/data/weights/',\n",
" 'dave': 's3://my-bucket/data/daves_weird_data/'}\n",
"\n",
"pytorch_estimator.fit(channels)\n",
"```\n",
"\n",
"At the start of training, SageMaker will create a set of corresponding directories on each training node:\n",
"\n",
"```\n",
"/opt/ml/input/data/train/\n",
"/opt/ml/input/data/test/\n",
"/opt/ml/input/data/weights/\n",
"/opt/ml/input/data/dave/\n",
"```\n",
"\n",
"SageMaker will then copy all the contents of the corresponding S3 locations to these directories, which you can then access in training.\n",
"\n",
"One downside of setting up channels like this is that it requires all the data to be downloaded to your instance at the start of of training, which can delay the training launch if you're dealing with a large dataset.\n",
"\n",
"We have two ways to speed up launch. The first is [Fast File Mode](https://aws.amazon.com/about-aws/whats-new/2021/10/amazon-sagemaker-fast-file-mode/) which downloads data from S3 as it's requested by the training model, speeding up your launch time. You can use fast file mode by sepcifying `TrainingInputMode='FastFile'` in your SageMaker estimator configuration. \n",
"\n",
"If you're dealing with really huge data, on the order of several terabytes, you might not want to keep if on the instance at all, and just stream it directly from S3 into your model. The [PyTorch S3 plugin](https://aws.amazon.com/blogs/machine-learning/announcing-the-amazon-s3-plugin-for-pytorch/) provides just this capability. To use the PyTorch S3 plugin, you need to build a PyTorch dataset using the S3 plugin base class. S3 plugin support is already built into SageMakerCV. If you want to see how it's implemented, the dataset can be found in `sagemakercv/data/datasets/coco.py`. \n",
"\n",
"In our case, we'll use a mix of channels and the S3 plugin. We'll download the smaller pieces at the start of training (the validation data, pretrained weights, and image annotations), and we'll use the S3 plugin to stream the training data, since it's large. SageMakerCV will automatically switch to the S3 plugin when you supply an S3 location in the configuration file. So all we need to do is setup our channels for the data we want to download, then give the config file the locations either on the instance or on S3.\n",
"\n",
"We also want to set the output dir to `/opt/ml/checkpoints`. SageMaker will sync the contents of this directory back to S3. Let's also turn up the number of workers so the data loads fast enough for the bigger GPUs we're using. \n",
"\n",
"First, we setup our training channels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"channels = {'validation': os.path.join(S3_DATA_LOCATION, 'val2017'),\n",
" 'weights': S3_WEIGHTS_LOCATION,\n",
" 'annotations': os.path.join(S3_DATA_LOCATION, 'annotations')}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Next, in the configuration file, we need to provide the corresponding location on each training node for these channels. We need to specify locations for the validation data, train and validation annotations, and the backbone weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CHANNELS_DIR='/opt/ml/input/data/' # on node\n",
"cfg.INPUT.VAL_INPUT_DIR = os.path.join(CHANNELS_DIR, 'validation') # Corresponds to the vdalidation key in the channels\n",
"cfg.INPUT.TRAIN_ANNO_DIR = os.path.join(CHANNELS_DIR, 'annotations', 'instances_train2017.json')\n",
"cfg.INPUT.VAL_ANNO_DIR = os.path.join(CHANNELS_DIR, 'annotations', 'instances_val2017.json')\n",
"cfg.MODEL.WEIGHT=os.path.join(CHANNELS_DIR, 'weights', R50_WEIGHTS) # backbone weights file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For our training data location, we'll instead point it directly to S3 so it streams the data in."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.INPUT.TRAIN_INPUT_DIR = os.path.join(S3_DATA_LOCATION, \"train2017\") # Set to S3 location so we use the S3 plugin"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the output directory to the SageMaker's checkpoint directory. This way all the files our model writes out will be immediately copied to S3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.OUTPUT_DIR = '/opt/ml/checkpoints/'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Turn up the number of workers for the dataloader. This is especially important when using the S3 plugin. The plugin is fast, but there's still a bit of network overhead. Having more dataloaders grabbing elements from S3 simultaneously will mean our model isn't waiting around for data to load. A good rule of thunb is to set the number of dataloader workers equal to the number of vCPUs divded by the number of GPUs per instance. For example, on the `ml.p4d.24xlarge` there are 96 vCPUs and 8 GPUs, so we'll use 12 workers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.DATALOADER.NUM_WORKERS=12 "
]
},
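{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're running on a different instance type, a quick way to pick this value (a hedged sketch, not part of the training script) is to divide the instance's vCPU count by its GPU count:\n",
"\n",
"```\n",
"import multiprocessing\n",
"import torch\n",
"\n",
"num_workers = multiprocessing.cpu_count() // max(torch.cuda.device_count(), 1)\n",
"```\n",
"\n",
"Next we set the training hooks, which add checkpointing, automatic mixed precision (AMP), iteration timing, text logging, and COCO evaluation to the training loop."
]
},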
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg.HOOKS=[\"DetectronCheckpointHook\",\n",
" \"AMP_Hook\",\n",
" \"IterTimerHook\",\n",
" \"TextLoggerHook\",\n",
" \"COCOEvaluation\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like we did for the local training, we'll save the configuration file so we can replicate this same training later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dist_config_file = f\"configs/dist-training-config.yaml\"\n",
"with open(dist_config_file, 'w') as outfile:\n",
" with redirect_stdout(outfile): print(cfg.dump())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"And now we can launch our training job. Let's set the config file we just created as a hyperparameter for the estimator. This will get passed to our training script as a command line argument. For example, SageMaker would launch training with `python train.py --config configs/dist-training-config.yaml`\n",
"\n",
"Using 4 P4d nodes, training takes about 45 minutes. This section will also print a lot of output logs. By setting `wait=False` you can avoid printing logs in the notebook. This setting will just launch the job then return, and is useful for when you want to launch several jobs at the same time. You can then montior each job from the [SageMaker Training Console](https://us-west-2.console.aws.amazon.com/sagemaker)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparameters = {\"config\": dist_config_file}\n",
"\n",
"estimator = PyTorch(\n",
" entry_point=entry_point, \n",
" source_dir=source_dir, \n",
" py_version='py3',\n",
" framework_version='1.8.1',\n",
" role=role,\n",
" instance_count=nodes,\n",
" instance_type=instance_type,\n",
" distribution=distribution,\n",
" output_path=output_path,\n",
" checkpoint_s3_uri=output_path,\n",
" model_dir=output_path,\n",
" hyperparameters=hyperparameters,\n",
" volume_size=500,\n",
" code_location=code_location,\n",
" disable_profiler=True, # Reduce number of logs since we don't need profiler or debugger for this training\n",
" debugger_hook_config=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator.fit(channels, wait=True, job_name=job_name)"
]
},
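{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you launched with `wait=False`, you can re-attach to the job later to get the estimator back and follow its logs (a quick sketch using the SDK's `attach` classmethod):\n",
"\n",
"```\n",
"estimator = PyTorch.attach(job_name)\n",
"```"
]
},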
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"### Part 5: Visualizing Results\n",
"\n",
"And there you have it, a fully trained Mask RCNN model in under an hour. Now let's see how our model does on prediction by actually visualizing the output.\n",
"\n",
"Our model is stored at the S3 location we gave to the training job in `output_path`. We'll need to grab the results and store them on our studio instance so we can check performance, and visualize the output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from s3fs import S3FileSystem\n",
"from configs import cfg\n",
"\n",
"import torch\n",
"import gc\n",
"from sagemakercv.detection.detector import build_detection_model\n",
"from sagemakercv.utils.model_serialization import load_state_dict\n",
"from sagemakercv.data.datasets.evaluation.coco.coco_labels import coco_categories\n",
"from sagemakercv.utils.visualize import Visualizer\n",
"\n",
"# Turn down logging, we don't need all the info about loading the state dict\n",
"import logging\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.CRITICAL)\n",
"\n",
"# Reuse the local configuration file we made earlier\n",
"cfg.merge_from_file(local_config_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s3fs = S3FileSystem()\n",
"\n",
"TRAINING_OUTPUT=os.path.join(LOCAL_DATA_DIR, 'training_output') # Make a new local directory for our training results and new images\n",
"os.makedirs(TRAINING_OUTPUT, exist_ok=True)\n",
"\n",
"# Grab the filename of the last checkpoint our training job wrote\n",
"s3fs.get(os.path.join(output_path, 'last_checkpoint'), os.path.join(TRAINING_OUTPUT, 'last_checkpoint'))\n",
"with open(os.path.join(TRAINING_OUTPUT, 'last_checkpoint'), 'r') as f:\n",
" checkpoint_name=f.readline().split('/')[-1]\n",
" \n",
"# Copy the saved weights to our local directory\n",
"s3fs.get(os.path.join(output_path, checkpoint_name), os.path.join(TRAINING_OUTPUT, checkpoint_name))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"To run inference, we first build a blank model and then add weights from the distributed training job.\n",
"\n",
"We can build a new model the same way we built one for local training. This time, instead of using the checkpointer to load the backbone weights, we'll directly load all the saved weights from the trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(cfg.MODEL.DEVICE)\n",
"model = build_detection_model(cfg)\n",
"_ = model.to(device)\n",
"\n",
"# Load weights file as a dictionary\n",
"weights = torch.load(os.path.join(TRAINING_OUTPUT, checkpoint_name))['model']\n",
"\n",
"# Map saved weights to our model\n",
"load_state_dict(model, weights, False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"Like we did with the local training, we can pass the model to the visualization tool, and give it an image either from local, S3, or a web address."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"viz = Visualizer(model, cfg, temp_dir=TRAINING_OUTPUT, categories=coco_categories)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_src = 'https://cdn.pixabay.com/photo/2021/07/29/14/48/new-york-6507350_1280.jpg'\n",
"#'https://cdn.pixabay.com/photo/2020/05/12/11/39/cat-5162540__480.jpg'\n",
"viz(image_src, threshold=0.9) #explain threshold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Conclusion\n",
"\n",
"In this notebook, we've walked through the entire process of training Mask RCNN on SageMaker. We've implemented several of SageMaker's more advanced features, such as distributed training, EFA, and streaming data directly from S3. From here you can use the provided template datasets to train on your own data, or modify the framework with your own object detection model.\n",
"\n",
"When you're done, make sure to check that all of your SageMaker training jobs have stopped by checking the [SageMaker Training Console](https://us-west-2.console.aws.amazon.com/sagemaker). Also check that you've stopped any Studio instance you have running by selecting the session monitor on the left (the circle with a square in it), and clicking the power button next to any running instances. Your files will still be saved on the Studio EBS volume.\n",
"\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"instance_type": "ml.g4dn.xlarge",
"kernelspec": {
"display_name": "Python 3 (PyTorch 1.6 Python 3.6 GPU Optimized)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.6-gpu-py36-cu110-ubuntu18.04-v3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}