# FairMOT Training in Amazon SageMaker

This notebook demonstrates how to train a [FairMOT](https://arxiv.org/abs/2004.01888) model with SageMaker and tune its hyper-parameters with a [SageMaker hyperparameter tuning job](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html).

## 1. SageMaker Initialization 
First, we upgrade the SageMaker Python SDK to the latest version. If your notebook already uses the latest SageMaker 2.x API, you may skip the next cell.

In [None]:
! pip install --upgrade pip
! python3 -m pip install --upgrade sagemaker

In [None]:
%%time
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.estimator import Estimator

role = get_execution_role() # provide a pre-existing role ARN as an alternative to creating a new role
print(f'SageMaker Execution Role:{role}')

client = boto3.client('sts')
account = client.get_caller_identity()['Account']
print(f'AWS account:{account}')

session = boto3.session.Session()
region = session.region_name
print(f'AWS region:{region}')

In [None]:
s3_bucket = sagemaker.Session().default_bucket() 

# Use data parallelism ("dp") to train on a single multi-GPU instance, as in https://github.com/ifzhang/FairMOT
version_name = "dp"

# Currently we support MOT17 and MOT20
dataset_name = "MOT17" # Options: MOT17, MOT20

# 0: use all data for training; 1: use the second half of the data for validation.
# Set to 1 when running the hyperparameter tuning job, whose objective metric is val_loss.
half_val = 1

training_image = f"{account}.dkr.ecr.{region}.amazonaws.com/fairmot-sagemaker:pytorch1.8-{version_name}"
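Typos in these settings only surface later, when the shell script runs or the container starts, so a quick local check may save a failed job. The sketch below rebuilds the same ECR URI and checks the options listed in the comments above; the helper names `build_training_image_uri` and `check_config` are hypothetical and not part of any SDK.

```python
# Hypothetical sanity check for the configuration cell above (not part of the SageMaker SDK).
SUPPORTED_DATASETS = {"MOT17", "MOT20"}  # options listed in the comment above


def build_training_image_uri(account: str, region: str, version_name: str) -> str:
    """Rebuild the ECR image URI used by this notebook."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/fairmot-sagemaker:pytorch1.8-{version_name}"


def check_config(dataset_name: str, half_val: int) -> None:
    """Raise early if the dataset or the validation flag is not one of the documented options."""
    if dataset_name not in SUPPORTED_DATASETS:
        raise ValueError(f"dataset_name must be one of {sorted(SUPPORTED_DATASETS)}, got {dataset_name!r}")
    if half_val not in (0, 1):
        raise ValueError(f"half_val must be 0 or 1, got {half_val!r}")
```

In the notebook, `check_config(dataset_name, half_val)` followed by `assert training_image == build_training_image_uri(account, region, version_name)` would confirm the cell above is consistent.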

## 2. Stage dataset in Amazon S3

We use the dataset from [MOT Challenge](https://motchallenge.net) for training. First, we download the dataset to this notebook instance. Following [DATA ZOO](https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/DATASET_ZOO.md), we convert it into the format that `FairMOT` expects and upload the processed dataset to the Amazon [S3 bucket](https://docs.aws.amazon.com/en_pv/AmazonS3/latest/gsg/CreatingABucket.html).

In [None]:
!cat ./prepare-s3-bucket.sh

Run the script [`prepare-s3-bucket.sh`](prepare-s3-bucket.sh) with your *Amazon S3 bucket*, *dataset name*, and *validation flag* as arguments. You can skip this step if you have already uploaded the dataset to your S3 bucket.

In [None]:
%%time
!./prepare-s3-bucket.sh {s3_bucket} {dataset_name} {half_val}
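To confirm that the upload finished, you can count the objects under the training prefix. This is a sketch that assumes `prepare-s3-bucket.sh` uploads to the `fairmot/sagemaker/input/train` prefix that section 4 reads from; the helper `count_s3_objects` is hypothetical, while `get_paginator("list_objects_v2")` is the standard boto3 S3 paginator.

```python
# Hypothetical helper: count the objects under an S3 prefix to confirm the upload finished.
# Assumes prepare-s3-bucket.sh uploads to the "fairmot/sagemaker/input/train" prefix used in section 4.


def count_s3_objects(s3_client, bucket: str, prefix: str) -> int:
    """Return the number of objects whose key starts with `prefix`, following pagination."""
    paginator = s3_client.get_paginator("list_objects_v2")
    total = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # "KeyCount" is the number of keys returned in this page (0 when nothing matches).
        total += page.get("KeyCount", 0)
    return total
```

For example, `count_s3_objects(boto3.client("s3"), s3_bucket, "fairmot/sagemaker/input/train")` returning 0 would suggest the script did not upload anything to that prefix.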

## 3. Build and push SageMaker training image
We build our own container from the [FairMOT](https://github.com/ifzhang/FairMOT) implementation and push the image to [Amazon ECR](https://aws.amazon.com/ecr/).

### Docker Environment Preparation
Because the container image may be larger than the free space on the notebook instance's root volume, we store Docker data under ```/home/ec2-user/SageMaker/docker```, which lives on the notebook instance's ML storage volume.

By default, Docker's root directory is ```/var/lib/docker/```. We need to change it to ```/home/ec2-user/SageMaker/docker```.
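`prepare-docker.sh` is not reproduced here. Assuming it works the usual way, by setting Docker's `data-root` option in `/etc/docker/daemon.json` and restarting the daemon, the configuration change amounts to the merge sketched below. The function `with_data_root` is hypothetical; it only builds the JSON text and does not write the file or restart Docker.

```python
import json

# Hypothetical sketch: add Docker's "data-root" option to existing daemon.json content
# while keeping every other key (for example, a GPU "runtimes" entry) unchanged.
DOCKER_DATA_ROOT = "/home/ec2-user/SageMaker/docker"


def with_data_root(daemon_json_text: str, data_root: str = DOCKER_DATA_ROOT) -> str:
    """Return daemon.json content with "data-root" set to `data_root`."""
    config = json.loads(daemon_json_text) if daemon_json_text.strip() else {}
    config["data-root"] = data_root
    return json.dumps(config, indent=4)
```

The Docker daemon only reads `daemon.json` at startup, so it has to be restarted after the file changes; presumably the script takes care of that.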

In [None]:
!cat /etc/docker/daemon.json

In [None]:
!bash ./prepare-docker.sh

### Build training image for FairMOT
Use the script [`./container-dp/build_tools/build_and_push.sh`](./container-dp/build_tools/build_and_push.sh) to build the FairMOT training image and push it to [Amazon ECR](https://aws.amazon.com/ecr/).

In [None]:
!cat ./container-{version_name}/build_tools/build_and_push.sh

Run the cell below with your *AWS region* as the argument.

In [None]:
%%time
!bash ./container-{version_name}/build_tools/build_and_push.sh {region}

## 4. Define SageMaker Data Channels
In this step, we define the SageMaker `train` data channel.

In [None]:
from sagemaker.inputs import TrainingInput
prefix = "fairmot/sagemaker" #prefix in your S3 bucket
s3train = f's3://{s3_bucket}/{prefix}/input/train'

train_input = TrainingInput(s3_data=s3train, 
                            distribution="FullyReplicated", 
                            s3_data_type='S3Prefix', 
                            input_mode='File')

data_channels = {'train': train_input}
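Inside the training container, SageMaker makes each channel available under `/opt/ml/input/data/<channel_name>`, so in `File` mode the `train` channel above appears at `/opt/ml/input/data/train`. The sketch below maps channel names to those paths; `container_channel_paths` is a hypothetical illustration, and how the FairMOT container consumes the directory is not shown in this notebook.

```python
import posixpath

# Root under which SageMaker places input channels in File mode.
SAGEMAKER_INPUT_DATA_DIR = "/opt/ml/input/data"


def container_channel_paths(channel_names):
    """Map each channel name to the directory SageMaker creates for it in the container."""
    return {name: posixpath.join(SAGEMAKER_INPUT_DATA_DIR, name) for name in channel_names}
```

Passing `data_channels` works because iterating a dict yields its keys: `container_channel_paths(data_channels)` gives `{'train': '/opt/ml/input/data/train'}`.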

Next, we define the location in the S3 bucket where the model output is stored.

In [None]:
s3_output_location = f's3://{s3_bucket}/{prefix}/output'

## 5. Configure Hyper-parameters
In this step, we define the hyper-parameters used in FairMOT. Jump to [8. Hyperparameter Tuning](#hyperparametertuning) if you want to run a hyperparameter tuning job instead.

<table align='left'>
    <caption>FairMOT  Hyper-parameters</caption>
    <tr>
    <th style="text-align:center">Hyper-parameter</th>
    <th style="text-align:center">Description</th>
    <th style="text-align:center">Default</th>
    </tr>
     <tr>
        <td style="text-align:center">arch</td>
        <td style="text-align:left">model architecture. Currently tested resdcn_34 | resdcn_50 | resfpndcn_34 | dla_34 | hrnet_18</td>
        <td style="text-align:center">'dla_34'</td>
    </tr>
    <tr>
        <td style="text-align:center">load_model</td>
        <td style="text-align:left">pretrained model</td>
        <td style="text-align:center">fairmot_dla34.pth</td>
    </tr>
    <tr>
        <td style="text-align:center">head_conv</td>
        <td style="text-align:left">number of conv-layer channels in each output head: 0 for no conv layer, -1 for the default (256 for both ResNets and DLA).</td>
        <td style="text-align:center">-1</td>
    </tr>
    <tr>
        <td style="text-align:center">down_ratio</td>
        <td style="text-align:left">output stride. Currently only supports 4.</td>
        <td style="text-align:center">4</td>
    </tr>
    <tr>
        <td style="text-align:center">input_res</td>
        <td style="text-align:left">input height and width; -1 to use the dataset default. Overridden by input_h | input_w.</td>
        <td style="text-align:center">-1</td>
    </tr>
    <tr>
        <td style="text-align:center">input_h</td>
        <td style="text-align:left">input height</td>
        <td style="text-align:center">608</td>
    </tr>
    <tr>
        <td style="text-align:center">input_w</td>
        <td style="text-align:left">input width</td>
        <td style="text-align:center">1088</td>
    </tr>
    <tr>
        <td style="text-align:center">lr</td>
        <td style="text-align:left">learning rate for batch size 12.</td>
        <td style="text-align:center">1e-4</td>
    </tr>
    <tr>
        <td style="text-align:center">lr_step</td>
        <td style="text-align:left">epoch(s), comma-separated, at which the learning rate is divided by 10.</td>
        <td style="text-align:center">'20'</td>
    </tr>
    <tr>
        <td style="text-align:center">num_epochs</td>
        <td style="text-align:left">total training epochs.</td>
        <td style="text-align:center">30</td>
    </tr>
    <tr>
        <td style="text-align:center">batch_size</td>
        <td style="text-align:left">batch size; 8 is recommended on ml.p3 instances.</td>
        <td style="text-align:center">8</td>
    </tr>
    <tr>
        <td style="text-align:center">num_iters</td>
        <td style="text-align:left">iterations per epoch; -1 means #samples / batch_size.</td>
        <td style="text-align:center">-1</td>
    </tr>
    <tr>
        <td style="text-align:center">val_intervals</td>
        <td style="text-align:left">interval, in epochs, between validation runs.</td>
        <td style="text-align:center">5</td>
    </tr>
    <tr>
        <td style="text-align:center">reg_loss</td>
        <td style="text-align:left">regression loss: sl1 | l1 | l2</td>
        <td style="text-align:center">'l1'</td>
    </tr>
    <tr>
        <td style="text-align:center">hm_weight</td>
        <td style="text-align:left">loss weight for keypoint heatmaps.</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">off_weight</td>
        <td style="text-align:left">loss weight for keypoint local offsets.</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">wh_weight</td>
        <td style="text-align:left">loss weight for bounding box size.</td>
        <td style="text-align:center">0.1</td>
    </tr>
    <tr>
        <td style="text-align:center">id_loss</td>
        <td style="text-align:left">reid loss: ce | focal</td>
        <td style="text-align:center">'ce'</td>
    </tr>
    <tr>
        <td style="text-align:center">id_weight</td>
        <td style="text-align:left">loss weight for id</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">reid_dim</td>
        <td style="text-align:left">feature dim for reid</td>
        <td style="text-align:center">128</td>
    </tr>
</table>

In [None]:
hyperparameters = {
                    "batch_size": 8,
                    "num_epochs": 20,
                    "val_intervals": 1,
                    "load_model": 'fairmot_dla34.pth',
                    "data_name": dataset_name  # MOT17 or MOT20, set in section 1
                  }
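SageMaker delivers these hyperparameters to the container as a JSON object in `/opt/ml/input/config/hyperparameters.json`, with every value serialized as a string (which is also why section 8 writes `StaticHyperParameters` as strings). A container entry point therefore has to cast them back. The sketch below shows one way to do that; the `HP_TYPES` table and `load_hyperparameters` function are hypothetical and do not describe the FairMOT container's actual entry point.

```python
import json
import os

# Hypothetical type table for a few of the hyperparameters used in this notebook.
HP_TYPES = {"batch_size": int, "num_epochs": int, "val_intervals": int, "lr": float}


def load_hyperparameters(config_dir="/opt/ml/input/config"):
    """Read SageMaker's hyperparameters.json and cast known numeric values back from strings."""
    with open(os.path.join(config_dir, "hyperparameters.json")) as f:
        raw = json.load(f)  # every value arrives as a string, e.g. {"batch_size": "8"}
    # Unknown keys (such as load_model or data_name) are kept as strings.
    return {key: HP_TYPES.get(key, str)(value) for key, value in raw.items()}
```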

## 6. Define Training Metrics
Next, we define the regular expressions that SageMaker uses to extract algorithm metrics from the training logs and send to [Amazon CloudWatch metrics](https://docs.aws.amazon.com/en_pv/AmazonCloudWatch/latest/monitoring/working_with_metrics.html). These metrics are visualized in the SageMaker console.

In [None]:
# FairMOT loss names reported for the training and validation splits.
loss_names = ["loss", "hm_loss", "wh_loss", "id_loss", "off_loss"]

# Each regex captures the value that follows "|<split>_<loss>" in the training log.
# Raw strings avoid Python's invalid-escape warning for "\|".
metric_definitions = [
    {"Name": f"{split}_{loss}", "Regex": rf"\|{split}_{loss}\s*(\S+).*"}
    for split in ("train", "val")
    for loss in loss_names
]
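To check the regular expressions before launching a job, you can run them locally against a log line. FairMOT's actual log layout inside this container is not shown in this notebook, so the sample line below is a hypothetical `|name value` layout chosen to match the patterns above; SageMaker records the value in each pattern's capture group.

```python
import re

# Hypothetical log line; the real FairMOT log layout may differ.
SAMPLE_LOG_LINE = "epoch: 3 |train_loss 1.234567 |train_hm_loss 0.456 |val_loss 1.5"


def extract_metrics(line, definitions):
    """Apply each metric regex to `line` and return {name: float} for the patterns that match."""
    values = {}
    for definition in definitions:
        match = re.search(definition["Regex"], line)
        if match:
            values[definition["Name"]] = float(match.group(1))
    return values
```

For example, `extract_metrics(SAMPLE_LOG_LINE, metric_definitions)` should report `train_loss`, `train_hm_loss` and `val_loss`, and nothing for losses absent from the line.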

## 7. Define SageMaker Training Job

Next, we use SageMaker [Estimator](https://sagemaker.readthedocs.io/en/stable/estimators.html) API to define a SageMaker Training Job.

In [None]:
sagemaker_session = sagemaker.session.Session(boto_session=session)

fairmot_estimator = Estimator(image_uri=training_image,
                                role=role, 
                                instance_count=1,
                                instance_type='ml.p3.16xlarge',
                                volume_size = 100,
                                max_run = 40000,
                                output_path=s3_output_location,
                                sagemaker_session=sagemaker_session, 
                                hyperparameters = hyperparameters,
                                metric_definitions = metric_definitions,
                               )

Finally, we launch the SageMaker training job.

In [None]:
import time

job_name=f'fairmot-{version_name}-{int(time.time())}'
print(f"Launching Training Job: {job_name}")

# set wait=True below if you want to print logs in cell output
fairmot_estimator.fit(inputs=data_channels, job_name=job_name, logs="All", wait=False)
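SageMaker rejects job names that break its naming rules, and the error only appears when the job is created. The sketch below encodes the documented constraints (letters, digits and hyphens, starting and ending with a letter or digit; at most 63 characters for training jobs and 32 for hyperparameter tuning jobs) so a name can be checked locally; `is_valid_job_name` is a hypothetical helper.

```python
import re

# Job names must start and end with a letter or digit; hyphens are allowed in between.
_NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")
MAX_TRAINING_JOB_NAME = 63
MAX_TUNING_JOB_NAME = 32


def is_valid_job_name(name: str, max_length: int = MAX_TRAINING_JOB_NAME) -> bool:
    """Return True if `name` fits SageMaker's job-name pattern and length limit."""
    return len(name) <= max_length and _NAME_PATTERN.fullmatch(name) is not None
```

The training job name built above (`fairmot-dp-` plus a 10-digit timestamp) is 21 characters, well within the limit.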

Check the metrics of the training job in the `Training Job` console.

In [None]:
from IPython.display import display, HTML

display(
    HTML(
    f'<b><a href="https://console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}">Check the status of training job</a></b>'
    )
)

**Once the training job above has completed**, we store the S3 URI of the model artifact in IPython's database as a variable. This variable will be used later to serve the model.
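Because `fit` was called with `wait=False`, the cell returns immediately, and the model artifact URI is only reliably available after the job completes. One way to block until then is boto3's `training_job_completed_or_stopped` waiter, sketched below. The helper name is hypothetical, and the polling settings are assumptions sized to cover the 40000-second `max_run` set above.

```python
def wait_for_training_job(sm_client, job_name, delay_seconds=60, max_attempts=720):
    """Block until the training job completes or stops, then return its model artifact URI.

    720 attempts x 60 s = 12 hours, which covers the 40000 s max_run set above.
    """
    waiter = sm_client.get_waiter("training_job_completed_or_stopped")
    waiter.wait(
        TrainingJobName=job_name,
        WaiterConfig={"Delay": delay_seconds, "MaxAttempts": max_attempts},
    )
    description = sm_client.describe_training_job(TrainingJobName=job_name)
    status = description["TrainingJobStatus"]
    # A stopped job has no usable artifact, so anything other than Completed is an error here.
    if status != "Completed":
        raise RuntimeError(f"{job_name} ended with status {status}")
    return description["ModelArtifacts"]["S3ModelArtifacts"]
```

For example, `s3_model_uri = wait_for_training_job(boto3.client("sagemaker"), job_name)` blocks the notebook until the artifact exists.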

In [None]:
s3_model_uri = fairmot_estimator.model_data
%store s3_model_uri

<a id='hyperparametertuning'></a>
## 8. Hyperparameter Tuning
In this step, we define and launch a hyperparameter tuning job. `MaxParallelTrainingJobs` must be <span style="color:red;">**less than or equal to your service limit for the training instance type**</span>. We tune `id_loss` and `lr`, and use `val_loss` as the objective metric.

As [Best Practices for Hyperparameter Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-considerations.html) suggests, a tuning job improves only through successive rounds of experiments. Therefore, a smaller `MaxParallelTrainingJobs` and a larger `MaxNumberOfTrainingJobs` may lead to a better result. When `MaxParallelTrainingJobs` equals `MaxNumberOfTrainingJobs`, the search effectively becomes `Random Search` even if the strategy is set to `Bayesian Search`. In this demonstration, we set `MaxParallelTrainingJobs` to 1.

A larger `MaxNumberOfTrainingJobs` can give a better result, but it takes longer. We set `MaxNumberOfTrainingJobs` to the small value 3 only to show how SageMaker hyperparameter tuning works. When you train a model on your own dataset, we recommend a larger value.
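A rough way to see the trade-off is to count how many sequential rounds the tuner gets to learn from earlier results. The sketch below makes a simplifying assumption (each batch of parallel jobs finishes before the next starts), which is not exactly how SageMaker schedules jobs; `sequential_rounds` is a hypothetical helper.

```python
import math


def sequential_rounds(max_jobs: int, max_parallel: int) -> int:
    """Number of batches needed if every batch runs up to max_parallel jobs at once."""
    return math.ceil(max_jobs / max_parallel)
```

With the settings used here, `sequential_rounds(3, 1)` is 3, so each later job can use the results of earlier ones. With `sequential_rounds(3, 3)` there is only 1 round and nothing to learn from, which matches the random-search behavior described above.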

For more details on hyperparameter tuning with SageMaker, see [How Hyperparameter Tuning Works](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html).

In [None]:
from time import gmtime, strftime

# Tuning job names are limited to 32 characters; with version_name = "dp" this name is exactly 32.
tuning_job_name = f'fairmot-tuningjob-{version_name}-' + strftime("%d-%H-%M-%S", gmtime())

print(tuning_job_name)

tuning_job_config = {
    "ParameterRanges": {
      "CategoricalParameterRanges": [
          {
              "Name": "id_loss",
              "Values": ['ce', 'focal']
          }
      ],
      "ContinuousParameterRanges": [
        {
          "Name": "lr",
          "MaxValue": "1e-3",
          "MinValue": "1e-5",
          "ScalingType": "Auto"
        }
      ]
    },
    "ResourceLimits": {
      "MaxNumberOfTrainingJobs": 3,
      "MaxParallelTrainingJobs": 1
    },
    "Strategy": "Bayesian",
    "HyperParameterTuningJobObjective": {
      "MetricName": "val_loss",
      "Type": "Minimize"
    }
  }
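`"ScalingType": "Auto"` leaves the choice of scale to SageMaker; `"Linear"`, `"Logarithmic"` and `"ReverseLogarithmic"` are the explicit alternatives. For a learning-rate range that spans two orders of magnitude, the scale matters a lot, as this small calculation shows. The helper `fraction_below` is a hypothetical illustration of uniform sampling, not SageMaker's search algorithm.

```python
import math

LR_MIN, LR_MAX = 1e-5, 1e-3  # the "lr" range in tuning_job_config


def fraction_below(threshold, low=LR_MIN, high=LR_MAX, log_scale=False):
    """Probability that a uniform draw from [low, high] (optionally on a log scale) is below threshold."""
    if log_scale:
        return (math.log10(threshold) - math.log10(low)) / (math.log10(high) - math.log10(low))
    return (threshold - low) / (high - low)
```

Sampled linearly, only about 9% of learning rates fall below 1e-4 (`fraction_below(1e-4)` is 1/11); on a log scale, half of them do. If you want to pin this behavior, setting `"ScalingType": "Logarithmic"` is one option.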

In [None]:
# Same metric definitions as in section 6, rebuilt here so that this section can run on its own.
# Raw strings avoid Python's invalid-escape warning for "\|".
tuning_metric_definitions = [
    {"Name": f"{split}_{loss}", "Regex": rf"\|{split}_{loss}\s*(\S+).*"}
    for split in ("train", "val")
    for loss in ("loss", "hm_loss", "wh_loss", "id_loss", "off_loss")
]

training_job_definition = {
    "AlgorithmSpecification": {
      "MetricDefinitions": tuning_metric_definitions,
      "TrainingImage": training_image,
      "TrainingInputMode": "File"
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3train,
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None"
        }
    ],
    "OutputDataConfig": {
      "S3OutputPath": s3_output_location
    },
    "ResourceConfig": {
      "InstanceCount": 1,
      "InstanceType": "ml.p3.16xlarge",
      "VolumeSizeInGB": 100
    },
    "RoleArn": role,
    "StaticHyperParameters": {
        # The low-level API expects every hyperparameter value as a string.
        "num_epochs": "20",
        "val_intervals": "1",
        "batch_size": "8",
        "load_model": "fairmot_dla34.pth",
        "data_name": dataset_name
    },
    "StoppingCondition": {
      "MaxRuntimeInSeconds": 72000
    }
}

Then we launch the defined hyperparameter tuning job.

In [None]:
smclient = boto3.client('sagemaker')
smclient.create_hyper_parameter_tuning_job(HyperParameterTuningJobName = tuning_job_name,
                                               HyperParameterTuningJobConfig = tuning_job_config,
                                               TrainingJobDefinition = training_job_definition)

In [None]:
smclient.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName = tuning_job_name)['HyperParameterTuningJobStatus']
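The call above only reports the overall state. Once training jobs finish, `describe_hyper_parameter_tuning_job` also returns a `BestTrainingJob` summary with the tuned hyperparameters and the final objective value. The hypothetical helper below collects those fields and tolerates their absence early in the run.

```python
def summarize_tuning_job(sm_client, tuning_job_name):
    """Return the tuning job status and, once available, the best training job found so far."""
    desc = sm_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuning_job_name)
    summary = {"status": desc["HyperParameterTuningJobStatus"]}
    best = desc.get("BestTrainingJob")  # may be absent until a training job has finished
    if best:
        summary["best_job"] = best["TrainingJobName"]
        summary["tuned_hyperparameters"] = best["TunedHyperParameters"]
        metric = best.get("FinalHyperParameterTuningJobObjectiveMetric")
        if metric:
            summary["objective"] = (metric["MetricName"], metric["Value"])
    return summary
```

For example, `summarize_tuning_job(smclient, tuning_job_name)` can be re-run while the job is in progress to see which `id_loss` and `lr` currently give the lowest `val_loss`.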

Check the status of the hyperparameter tuning job in the `Hyperparameter tuning jobs` console.

In [None]:
from IPython.display import display, HTML

display(
    HTML(
    f'<b><a href="https://console.aws.amazon.com/sagemaker/home?region={region}#/hyper-tuning-jobs/{tuning_job_name}">Check hyperparameter tuning job</a></b>'
    )
)