# ByteTrack Training with Amazon SageMaker

This notebook demonstrates how to train a [ByteTrack](https://arxiv.org/abs/2110.06864) model with SageMaker and tune its hyperparameters with a [SageMaker hyperparameter tuning job](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html). We use the dataset labeled with SageMaker Ground Truth in [data-preparation.ipynb](data-preparation.ipynb).

## 1. SageMaker Initialization 
First we upgrade SageMaker to the latest version. If your notebook is already using the latest SageMaker 2.x API, you may skip the next cell.

In [None]:
! pip install --upgrade pip
! python3 -m pip install --upgrade sagemaker

In [None]:
import datetime
datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

In [None]:
%%time
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.estimator import Estimator

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor

role = get_execution_role() # provide a pre-existing role ARN as an alternative to creating a new role
print(f'SageMaker Execution Role:{role}')

client = boto3.client('sts')
account = client.get_caller_identity()['Account']
print(f'AWS account:{account}')

session = boto3.session.Session()
region = session.region_name
print(f'AWS region:{region}')

In [None]:
s3_bucket = "<your-s3-bucket-name>"  # replace with your S3 bucket name

dataset_name = "<your-dataset-name>"  # replace with your dataset name

prefix = "<prefix>"  # replace with your S3 key prefix

training_image = f"{account}.dkr.ecr.{region}.amazonaws.com/bytetrack-sagemaker:pytorch1.12.1"

## 2. Prepare dataset

We prepare our dataset as below:
- Convert SageMaker Ground Truth annotation into MOT Challenge annotation
- Convert MOT annotation into MSCOCO annotation

Because we keep the dataset in both MOT and MSCOCO formats, you can also train MOT algorithms that do not separate detection from tracking, such as [FairMOT](https://arxiv.org/abs/2004.01888), on the MOT-format dataset. In addition, you can easily swap the detector for another algorithm such as YOLOv7 to reuse your existing object detection work.

### 2.1 Convert SageMaker Ground Truth annotation into MOT Challenge annotation

`sm_gt_uri` is the label data from SageMaker Ground Truth, and `data_uri` is the video frames from the original video file.

In [None]:
data_prefix = f"{prefix}/sample-data"

sm_gt_uri = f"s3://{s3_bucket}/{prefix}/mot-bytetrack-sample/" # annotation data
mot_uri = f"s3://{s3_bucket}/{prefix}/outputs-mot/" # output data with MOT format
data_uri = f"s3://{s3_bucket}/{data_prefix}" # Video frame

By default, if the input data for a tracking task is a video file, SageMaker Ground Truth extracts the frames and saves them next to the video file, in a folder named after the video. The Ground Truth conversion job copies this input directory onto the SageMaker Processing instance, where a file and a folder with the same name cannot coexist in one directory. We therefore delete the original video files first.

In [None]:
s3_client = boto3.client('s3')

def remove_video_from_s3():
    """Delete every .mp4 object under data_prefix, leaving the extracted frames."""
    # Paginate because list_objects_v2 returns at most 1,000 keys per call.
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=s3_bucket, Prefix=data_prefix):
        # 'Contents' is absent when a page holds no objects.
        for content in page.get('Contents', []):
            if content['Key'].endswith('.mp4'):
                s3_client.delete_object(Bucket=s3_bucket, Key=content['Key'])

remove_video_from_s3()

Create a SageMaker Processing job to convert the annotations.

In [None]:
labeling_job_name = "<your labeling job name>"  # replace with your Ground Truth labeling job name

In [None]:
sklearn_processor = SKLearnProcessor(
    framework_version="1.0-1",
    role=role,
    instance_type="ml.m5.xlarge",
    instance_count=1
)

sklearn_processor.run(
    code="gt_processing/convert_gt.py",
    inputs=[
        ProcessingInput(source=sm_gt_uri, destination="/opt/ml/processing/gt"),
        ProcessingInput(source=data_uri, destination="/opt/ml/processing/input")
    ],
    outputs=[
        ProcessingOutput(source="/opt/ml/processing/output", destination=mot_uri)
    ],
    arguments=[
        "--train-ratio", "0.6",
        "--val-ratio", "0.2",
        "--test-ratio", "0.2",
        "--labeling-job-name", labeling_job_name
    ],
)

Once this job finishes, a dataset in MOT17 annotation format is available under `mot_uri` in your S3 bucket.

### 2.2 Convert MOT annotation into MSCOCO annotation

ByteTrack uses [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) for detection and then runs tracking on the detections. YOLOX is trained on MSCOCO-format annotations, so we need to convert the MOT annotations into MSCOCO format.
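
To make the conversion concrete, here is a minimal sketch of turning one MOT Challenge ground-truth line (`frame, id, bb_left, bb_top, bb_width, bb_height, conf, class, visibility`) into a COCO-style annotation. It is an illustration only, not the contents of `gt_processing/mot_to_coco.py`; the function name `mot_line_to_coco` and the extra `track_id` field are assumptions made for this example.

```python
def mot_line_to_coco(line, image_id, annotation_id, category_id=1):
    """Convert one MOT ground-truth line into a COCO-style annotation dict."""
    fields = line.strip().split(",")
    track_id = int(fields[1])
    # MOT stores the box as left, top, width, height in pixels.
    left, top, width, height = (float(value) for value in fields[2:6])
    return {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_id,
        "bbox": [left, top, width, height],  # COCO boxes are [x, y, width, height]
        "area": width * height,
        "iscrowd": 0,
        "track_id": track_id,  # assumption: extra field, not part of standard COCO
    }


# Hypothetical sample line: frame 1, track 3, box at (794, 247) sized 71 x 174.
annotation = mot_line_to_coco("1,3,794.0,247.0,71.0,174.0,1,1,0.8",
                              image_id=1, annotation_id=1)
print(annotation["bbox"], annotation["area"])
```

Both formats describe a box by its top-left corner plus width and height, so the box itself needs no coordinate transformation; the work is mostly in regrouping per-sequence text lines into COCO's `images` and `annotations` lists.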

In [None]:
coco_uri = f"s3://{s3_bucket}/{data_prefix}/{dataset_name}/" # Output data with COCO format

In [None]:
mot2coco_processor = SKLearnProcessor(
    framework_version="1.0-1",
    role=role,
    instance_type="ml.m5.xlarge",
    instance_count=1
)

mot2coco_processor.run(
    code="gt_processing/mot_to_coco.py",
    inputs=[
        ProcessingInput(source=mot_uri, destination="/opt/ml/processing/mot")
    ],
    outputs=[
        ProcessingOutput(source="/opt/ml/processing/coco", destination=coco_uri)
    ]
)

## 3. Build and push SageMaker training image
We use the implementation of [ByteTrack](https://github.com/ifzhang/ByteTrack) to create our own container, and push the image to [Amazon ECR](https://aws.amazon.com/ecr/). For more details about how to use BYOC on SageMaker, please refer to [Adapting your own training container](https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-training-container.html).

### Docker Environment Preparation
Because the container image may be larger than the free space on the notebook instance's root volume, we store Docker's data under ```/home/ec2-user/SageMaker/docker```.

By default, Docker's data root is ```/var/lib/docker/```, so we change it to ```/home/ec2-user/SageMaker/docker```.
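
As a rough illustration of what this change amounts to, the sketch below merges a `data-root` entry into an existing `daemon.json` document. `data-root` is the Docker daemon option that relocates its storage directory. The helper `with_data_root` and the sample existing configuration are assumptions for this example; the actual steps performed by `prepare-docker.sh` are not shown here and may differ.

```python
import json


def with_data_root(existing_config_text, data_root):
    """Return daemon.json text with "data-root" set, keeping other settings."""
    # An empty or missing file is treated as an empty configuration.
    config = json.loads(existing_config_text) if existing_config_text.strip() else {}
    config["data-root"] = data_root
    return json.dumps(config, indent=4)


# Hypothetical existing configuration, for illustration only.
existing = '{"runtimes": {"nvidia": {"path": "nvidia-container-runtime", "runtimeArgs": []}}}'
updated = with_data_root(existing, "/home/ec2-user/SageMaker/docker")
print(updated)
```

Merging instead of overwriting keeps any runtime settings already present in the file. The Docker daemon must be restarted before a new data root takes effect.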

In [None]:
!cat /etc/docker/daemon.json

In [None]:
!bash ./prepare-docker.sh

### Build training image for ByteTrack
Use script [`./container-train/build_tools/build_and_push.sh`](./container-train/build_tools/build_and_push.sh) to build and push the ByteTrack training image to [Amazon ECR](https://aws.amazon.com/ecr/).

In [None]:
!cat ./container-train/build_tools/build_and_push.sh

Using your *AWS region* as argument, run the cell below.

In [None]:
%%time
!bash ./container-train/build_tools/build_and_push.sh {region}

## 4. Define SageMaker Data Channels
In this step, we define the SageMaker data channels: `mot` for the training dataset and `pretrain` for the pretrained weights.

Go to [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX/tree/0.1.0), download a pretrained model (YOLOX-x) from Standard Models, and upload it to your S3 bucket.

In [None]:
pretrain_model_s3uri = f's3://{s3_bucket}/{prefix}/pretrained-models'

In [None]:
!wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth
!aws s3 cp yolox_x.pth $pretrain_model_s3uri/yolox_x.pth
!rm yolox_x.pth

In [None]:
from sagemaker.inputs import TrainingInput

mot_input = TrainingInput(s3_data=coco_uri, 
                            distribution="FullyReplicated", 
                            s3_data_type='S3Prefix', 
                            input_mode='File')

pretrain_input = TrainingInput(s3_data=pretrain_model_s3uri, 
                            distribution="FullyReplicated", 
                            s3_data_type='S3Prefix', 
                            input_mode='File')

data_channels = {'mot': mot_input, 'pretrain': pretrain_input}
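
In `File` input mode, SageMaker downloads each channel into `/opt/ml/input/data/<channel_name>` inside the training container before the training code starts. The sketch below only computes those paths for the two channels defined above. How the training image actually reads them is an assumption, since the container code is not shown in this notebook.

```python
import posixpath

# Root directory where SageMaker places File-mode input channels.
INPUT_ROOT = "/opt/ml/input/data"

channel_names = ["mot", "pretrain"]
channel_dirs = {name: posixpath.join(INPUT_ROOT, name) for name in channel_names}

# Assumed location of the uploaded weights inside the 'pretrain' channel.
pretrained_model_path = posixpath.join(channel_dirs["pretrain"], "yolox_x.pth")

for name, path in channel_dirs.items():
    print(f"{name}: {path}")
print(pretrained_model_path)
```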

Next, we define the model output location in the S3 bucket.

In [None]:
s3_output_location = f's3://{s3_bucket}/{prefix}/output/'

## 5. Configure Hyper-parameters
In this step, we define the hyperparameters for training ByteTrack. Jump to [8. Hyperparameter Tuning](#hyperparametertuning) if you want to run a hyperparameter tuning job.

<table align='left'>
    <caption>ByteTrack Hyper-parameters</caption>
    <tr>
    <th style="text-align:center">Hyper-parameter</th>
    <th style="text-align:center">Description</th>
    <th style="text-align:center">Default</th>
    </tr>
     <tr>
        <td style="text-align:center">fp16</td>
        <td style="text-align:left">Use mixed-precision training</td>
        <td style="text-align:center">0 or 1</td>
    </tr>
    <tr>
        <td style="text-align:center">batch_size</td>
        <td style="text-align:left">Batch size, should be larger than number_gpu by 2</td>
        <td style="text-align:center">24</td>
    </tr>
    <tr>
        <td style="text-align:center">dataset_name</td>
        <td style="text-align:left">Assume there are several datasets, choose one dataset you want to train</td>
        <td style="text-align:center">'mot'</td>
    </tr>
    <tr>
        <td style="text-align:center">occupy</td>
        <td style="text-align:left">occupy GPU memory first for training, true by default.</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">pretrained_model</td>
        <td style="text-align:left">Pretrained model we want to use</td>
        <td style="text-align:center">`yolox_x.pth`</td>
    </tr>
    <tr>
        <td style="text-align:center">num_classes</td>
        <td style="text-align:left">number of classes</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">depth</td>
        <td style="text-align:left">depth</td>
        <td style="text-align:center">1.33</td>
    </tr>
    <tr>
        <td style="text-align:center">width</td>
        <td style="text-align:left">width</td>
        <td style="text-align:center">1.25</td>
    </tr>
    <tr>
        <td style="text-align:center">input_size_h</td>
        <td style="text-align:left">height in input size</td>
        <td style="text-align:center">800</td>
    </tr>
    <tr>
        <td style="text-align:center">input_size_w</td>
        <td style="text-align:left">width in input size</td>
        <td style="text-align:center">1440</td>
    </tr>
    <tr>
        <td style="text-align:center">test_size_h</td>
        <td style="text-align:left">height in test size</td>
        <td style="text-align:center">800</td>
    </tr>
    <tr>
        <td style="text-align:center">test_size_w</td>
        <td style="text-align:left">width in test size</td>
        <td style="text-align:center">1440</td>
    </tr>
    <tr>
        <td style="text-align:center">random_size_h</td>
        <td style="text-align:left">height in random size</td>
        <td style="text-align:center">18</td>
    </tr>
    <tr>
        <td style="text-align:center">random_size_w</td>
        <td style="text-align:left">width in random size</td>
        <td style="text-align:center">32</td>
    </tr>
    <tr>
        <td style="text-align:center">max_epoch</td>
        <td style="text-align:left">max epoch</td>
        <td style="text-align:center">80</td>
    </tr>
    <tr>
        <td style="text-align:center">print_interval</td>
        <td style="text-align:left">print_interval</td>
        <td style="text-align:center">20</td>
    </tr>
    <tr>
        <td style="text-align:center">eval_interval</td>
        <td style="text-align:left">eval_interval</td>
        <td style="text-align:center">5</td>
    </tr>
    <tr>
        <td style="text-align:center">test_conf</td>
        <td style="text-align:left">confidence threshold ranging from 0 to 1</td>
        <td style="text-align:center">0.001</td>
    </tr>
    <tr>
        <td style="text-align:center">nmsthre</td>
        <td style="text-align:left">IoU threshold of non-max suppression, ranging from 0 to 1</td>
        <td style="text-align:center">0.7</td>
    </tr>
    <tr>
        <td style="text-align:center">basic_lr_per_img</td>
        <td style="text-align:left">use a learning rate of lr×BatchSize/64, with an initial lr = 0.01 and the cosine lr schedule</td>
        <td style="text-align:center">0.001/64.0</td>
    </tr>
    <tr>
        <td style="text-align:center">no_aug_epochs</td>
        <td style="text-align:left">no_aug_epochs</td>
        <td style="text-align:center">10</td>
    </tr>
    <tr>
        <td style="text-align:center">warmup_epochs</td>
        <td style="text-align:left">warmup_epochs</td>
        <td style="text-align:center">1</td>
    </tr>
    <tr>
        <td style="text-align:center">infer_device</td>
        <td style="text-align:left">device type for inference</td>
        <td style="text-align:center">'cuda'</td>
    </tr>
</table>
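
To make the `basic_lr_per_img` entry concrete, here is a small worked calculation. It assumes the training code scales the learning rate the way YOLOX's experiment class does, that is, `lr = basic_lr_per_img * batch_size`. Whether the training image in this notebook follows exactly that rule is an assumption.

```python
# Defaults taken from the table above.
basic_lr_per_img = 0.001 / 64.0
batch_size = 24

# Assumed YOLOX-style scaling: the base learning rate grows linearly with batch size.
base_lr = basic_lr_per_img * batch_size

print(f"basic_lr_per_img = {basic_lr_per_img:.9f}")  # 0.000015625
print(f"base_lr          = {base_lr:.6f}")           # 0.000375
```

Under this assumption, changing `batch_size` to fit a smaller GPU also changes the effective learning rate, which is worth keeping in mind when comparing runs.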

In [None]:
hyperparameters = {
                    "batch_size": 24,
                    "max_epoch": 30,
                    "val_intervals": 1,
                    "pretrained_model": "yolox_x.pth",
                    "fp16": 0,
                    "infer_device": "cuda"
                  }

## 6. Define Training Metrics
Next, we define the regular expressions that SageMaker uses to extract algorithm metrics from the training logs and send them to [Amazon CloudWatch metrics](https://docs.aws.amazon.com/en_pv/AmazonCloudWatch/latest/monitoring/working_with_metrics.html). These algorithm metrics are visualized in the SageMaker console.

In [None]:
metric_definitions=[
            {
                "Name": "total_loss",
                "Regex": "total_loss: (.*?),"
            },
            {
                "Name": "iou_loss",
                "Regex": "iou_loss: (.*?),"
            },
            {
                "Name": "l1_loss",
                "Regex": "l1_loss: (.*?),"
            },
            {
                "Name": "conf_loss",
                "Regex": "conf_loss: (.*?),"
            },
            {
                "Name": "cls_loss",
                "Regex": "cls_loss: (.*?),"
            },
            {
                "Name": "lr",
                "Regex": "lr: (.*?),"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.75 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.75\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=small | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*small\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=medium | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*medium\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=large | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*large\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=1)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*1\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=10)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*10\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=small | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*small\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=medium | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*medium\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=large | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*large\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            }
    ]
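
Before launching a job, it can help to check a metric regex locally. The sketch below runs two patterns against sample log lines with Python's `re` module. The sample lines are assumptions: the first imitates a YOLOX-style training log line, and the second follows the column-padded summary format printed by `pycocotools`. The whitespace-tolerant `\s*` pattern is this example's own variant, and SageMaker's matching may differ in detail from `re.search`.

```python
import re

# Hypothetical log lines, for illustration only.
sample_log_lines = [
    "epoch: 1/30, iter: 20/100, total_loss: 5.3, iou_loss: 2.1, l1_loss: 0.0, "
    "conf_loss: 2.3, cls_loss: 0.9, lr: 1.200e-05, size: 800",
    " Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.512",
]

patterns = {
    "total_loss": r"total_loss: (.*?),",
    # \s* tolerates the padding spaces pycocotools uses to align columns.
    "AP": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)",
}


def extract_metrics(lines, metric_patterns):
    """Return the last value captured for each metric name."""
    found = {}
    for line in lines:
        for name, pattern in metric_patterns.items():
            match = re.search(pattern, line)
            if match:
                found[name] = float(match.group(1))
    return found


metrics = extract_metrics(sample_log_lines, patterns)
print(metrics)
```

If a metric never shows up in the console, a mismatch between the regex and the exact spacing of the log line is a common cause, so testing against a real line copied from the job's CloudWatch logs is worthwhile.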

## 7. Define SageMaker Training Job

Next, we use SageMaker [Estimator](https://sagemaker.readthedocs.io/en/stable/estimators.html) API to define a SageMaker Training Job.

A multi-GPU instance is not required in this solution; you can choose any other GPU instance for training. <span style="color:red">Note that you need to adjust the batch size to fit the GPU memory to avoid out-of-memory errors.</span>
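
As a rough guide when changing instance types, the sketch below computes the per-GPU batch size. It assumes the total `batch_size` is split evenly across GPUs, as YOLOX-style launchers do; that behavior and the helper name `per_gpu_batch_size` are assumptions for this example. An `ml.p3.16xlarge` has 8 GPUs, while an `ml.p3.2xlarge` has 1.

```python
def per_gpu_batch_size(total_batch_size, num_gpus):
    """Return the number of images each GPU processes per step."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if total_batch_size % num_gpus != 0:
        raise ValueError("total_batch_size should be divisible by num_gpus")
    return total_batch_size // num_gpus


# 24 images spread over the 8 GPUs of an ml.p3.16xlarge.
images_per_gpu = per_gpu_batch_size(24, 8)
print(images_per_gpu)  # 3
```

Keeping the per-GPU value roughly constant when moving to an instance with fewer GPUs of the same type is one simple way to stay within GPU memory.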

In [None]:
sagemaker_session = sagemaker.session.Session(boto_session=session)

bytetrack_estimator = Estimator(image_uri=training_image,
                                role=role, 
                                instance_count=1,
                                instance_type='ml.p3.16xlarge',
                                volume_size = 100,
                                max_run = 40000,
                                output_path=s3_output_location,
                                sagemaker_session=sagemaker_session, 
                                hyperparameters = hyperparameters,
                                metric_definitions = metric_definitions,
                               )

Finally, we launch the SageMaker training job.

In [None]:
import time

job_name = f'bytetrack-{int(time.time())}'
print(f"Launching Training Job: {job_name}")

# set wait=True below if you want to print logs in cell output
bytetrack_estimator.fit(inputs=data_channels, job_name=job_name, logs="All", wait=False)

Check the metrics of the training job in the `Training Job` console.

In [None]:
from IPython.display import display, HTML

display(
    HTML(
    f'<b><a href="https://console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}">Check the status of training job</a></b>'
    )
)
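
Because the job was started with `wait=False`, the notebook does not block until training ends. One way to wait programmatically is to poll `describe_training_job` until the status is terminal, as sketched below. The helper `wait_for_training_job` and the `FakeSageMakerClient` stand-in are assumptions written for this example so that it runs offline; in the notebook you would pass `boto3.client('sagemaker')` and the real `job_name`.

```python
import time

# Statuses after which a training job no longer changes.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sm_client, job_name, poll_seconds=60, sleep=time.sleep):
    """Poll describe_training_job until the job reaches a terminal state."""
    while True:
        description = sm_client.describe_training_job(TrainingJobName=job_name)
        status = description["TrainingJobStatus"]
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)


class FakeSageMakerClient:
    """Offline stand-in that reports InProgress twice, then Completed."""

    def __init__(self):
        self._statuses = iter(["InProgress", "InProgress", "Completed"])

    def describe_training_job(self, TrainingJobName):
        return {"TrainingJobStatus": next(self._statuses)}


final_status = wait_for_training_job(
    FakeSageMakerClient(), "demo-job", sleep=lambda seconds: None
)
print(final_status)
```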

**Once the training job above has completed**, we store the S3 URI of the model artifact in IPython’s database as a variable. This variable is used later to serve the model.

In [None]:
s3_model_uri = bytetrack_estimator.model_data
%store s3_model_uri

<a id='hyperparametertuning'></a>
## 8. Hyperparameter Tuning
In this step, we define and launch a hyperparameter tuning job. `MaxParallelTrainingJobs` should be <span style="color:red;">**less than or equal to your account's limit on training job instances**</span>. We tune `basic_lr_per_img` and use `total_loss` as the objective metric to minimize.

As [Best Practices for Hyperparameter Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-considerations.html) suggests, a tuning job improves only through successive rounds of experiments. Therefore, a smaller `MaxParallelTrainingJobs` and a larger `MaxNumberOfTrainingJobs` may lead to a better result. When `MaxParallelTrainingJobs` equals `MaxNumberOfTrainingJobs`, the search effectively becomes random search even if the strategy is set to `Bayesian`, because no job can learn from the results of earlier ones. In this demonstration, we set `MaxParallelTrainingJobs` to 1.

A larger `MaxNumberOfTrainingJobs` can give a better result, but it takes longer. We set `MaxNumberOfTrainingJobs` to a small value, 3, to show how SageMaker hyperparameter tuning works. When you train a model on your own dataset, we recommend setting `MaxNumberOfTrainingJobs` to a larger value.

For more details on hyperparameter tuning with SageMaker, see [How Hyperparameter Tuning Works](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html).

In [None]:
import json
from time import gmtime, strftime

tuning_job_name = 'bytetrack-tuningjob-' + strftime("%d%H%M%S", gmtime())

print(tuning_job_name)

tuning_job_config = {
    "ParameterRanges": {
      "ContinuousParameterRanges": [
        {
          "Name": "basic_lr_per_img",
          "MaxValue": "0.000016625",
          "MinValue": "0.000013625",
          "ScalingType": "Auto"
        }
      ]
    },
    "ResourceLimits": {
      "MaxNumberOfTrainingJobs": 3,
      "MaxParallelTrainingJobs": 1
    },
    "Strategy": "Bayesian",
    "HyperParameterTuningJobObjective": {
      "MetricName": "total_loss",
      "Type": "Minimize"
    }
  }

In [None]:
training_job_definition = {
    "AlgorithmSpecification": {
      "MetricDefinitions": [
            {
                "Name": "total_loss",
                "Regex": "total_loss: (.*?),"
            },
            {
                "Name": "iou_loss",
                "Regex": "iou_loss: (.*?),"
            },
            {
                "Name": "l1_loss",
                "Regex": "l1_loss: (.*?),"
            },
            {
                "Name": "conf_loss",
                "Regex": "conf_loss: (.*?),"
            },
            {
                "Name": "cls_loss",
                "Regex": "cls_loss: (.*?),"
            },
            {
                "Name": "lr",
                "Regex": "lr: (.*?),"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.75 | area=all | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.75\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=small | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*small\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=medium | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*medium\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AP(IoU=0.50:0.95 | area=large | maxDets=100)",
                "Regex": r"AP\) @\[ IoU=0\.50:0\.95\s*\| area=\s*large\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=1)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*1\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=10)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*10\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=all | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*all\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=small | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*small\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=medium | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*medium\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            },
            {
                "Name": "AR(IoU=0.50:0.95 | area=large | maxDets=100)",
                "Regex": r"AR\) @\[ IoU=0\.50:0\.95\s*\| area=\s*large\s*\| maxDets=\s*100\s*\] = ([0-9\.]+)"
            }
      ],
      "TrainingImage": training_image,
      "TrainingInputMode": "File"
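    },
    # Mirror the 'mot' and 'pretrain' channels defined in section 4.
    "InputDataConfig": [
        {
            "ChannelName": "mot",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": coco_uri,
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None"
        },
        {
            "ChannelName": "pretrain",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": pretrain_model_s3uri,
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None"
        }
    ],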
    "OutputDataConfig": {
      "S3OutputPath": s3_output_location
    },
    "ResourceConfig": {
      "InstanceCount": 1,
      "InstanceType": "ml.p3.16xlarge",
      "VolumeSizeInGB": 100
    },
    "RoleArn": role,
    "StaticHyperParameters": {
        "batch_size": "24",
        "max_epoch": "10",
        "val_intervals": "1",
        "pretrained_model": "yolox_x.pth",
        "fp16": "0",
        "infer_device": "cuda"
    },
    "StoppingCondition": {
      "MaxRuntimeInSeconds": 72000
    }
}

Then we launch the defined hyperparameter tuning job.

In [None]:
smclient = boto3.client('sagemaker')
smclient.create_hyper_parameter_tuning_job(HyperParameterTuningJobName = tuning_job_name,
                                               HyperParameterTuningJobConfig = tuning_job_config,
                                               TrainingJobDefinition = training_job_definition)

In [None]:
smclient.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName = tuning_job_name)['HyperParameterTuningJobStatus']

Check the status of the hyperparameter tuning job in the `Hyperparameter tuning jobs` console.

In [None]:
from IPython.display import display, HTML

display(
    HTML(
    f'<b><a href="https://console.aws.amazon.com/sagemaker/home?region={region}#/hyper-tuning-jobs/{tuning_job_name}">Check hyperparameter tuning job</a></b>'
    )
)
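
Once the tuning job has finished, the best training job can also be read from the `BestTrainingJob` field of the `describe_hyper_parameter_tuning_job` response. The sketch below shows one way to summarize it. The helper `summarize_best_job`, the `FakeTuningClient` stand-in, and every value it returns are made up for this example so that it runs offline; in the notebook you would pass `smclient` and `tuning_job_name`.

```python
def summarize_best_job(sm_client, tuning_job_name):
    """Return name, objective value and tuned hyperparameters of the best job."""
    response = sm_client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=tuning_job_name
    )
    best = response.get("BestTrainingJob")
    if not best:
        return None  # no training job has reported the objective metric yet
    metric = best.get("FinalHyperParameterTuningJobObjectiveMetric", {})
    return {
        "job_name": best["TrainingJobName"],
        "metric_name": metric.get("MetricName"),
        "metric_value": metric.get("Value"),
        "tuned_hyperparameters": best.get("TunedHyperParameters", {}),
    }


class FakeTuningClient:
    """Offline stand-in returning a made-up response, for illustration only."""

    def describe_hyper_parameter_tuning_job(self, HyperParameterTuningJobName):
        return {
            "HyperParameterTuningJobStatus": "Completed",
            "BestTrainingJob": {
                "TrainingJobName": "demo-tuning-job-002",
                "TunedHyperParameters": {"basic_lr_per_img": "0.000015"},
                "FinalHyperParameterTuningJobObjectiveMetric": {
                    "MetricName": "total_loss",
                    "Value": 2.5,
                },
            },
        }


summary = summarize_best_job(FakeTuningClient(), "demo-tuning-job")
print(summary)
```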