In [None]:
# install dependencies
!pip install -Uq sagemaker

# Using Amazon SageMaker Debugger for PyTorch Training Jobs

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

---

Amazon SageMaker is a managed platform to build, train, and host machine learning models. Amazon SageMaker Debugger is a feature that lets you debug machine learning and deep learning models during training by detecting problems with the models in real time.

Amazon SageMaker also gives you the option of bringing your own algorithms packaged in a custom container, which can then be trained and deployed in the Amazon SageMaker environment.

This notebook guides you through an example of using your own container with PyTorch for training, along with the recently added feature, Amazon SageMaker Debugger.

----

## How does Amazon SageMaker Debugger work?

Amazon SageMaker Debugger lets you go beyond scalars such as losses and accuracies and gives you full visibility into all tensors 'flowing through the graph' during training. It also helps you monitor training in real time using rules and CloudWatch events, and react to common training issues such as vanishing gradients or poor weight initialization.

### Concepts

* **Output Tensor**: These are the artifacts that define the state of the training job at any particular instant in its lifecycle.
* **Debug Hook**: Captures the tensors flowing through the training computational graph every N steps.
* **Debugging Rule**: Logic to analyze the tensors captured by the hook and report anomalies.

With these concepts in mind, let's look at the overall flow that Amazon SageMaker Debugger uses to orchestrate debugging.

Debugger operates in two steps: saving tensors and analyzing them.

### Saving tensors

The tensors that the debug hook captures are stored in an S3 location that you specify. There are two ways to configure Amazon SageMaker Debugger to save them:

   1. **Zero code change (DEPRECATED for PyTorch versions >= 1.12)**: If you use any of the SageMaker-provided [Deep Learning containers](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html), you don't need to make any changes to your training script for tensors to be stored. Amazon SageMaker Debugger uses the configuration you provide in the framework `Estimator` to save tensors in the fashion you specify.
       
       **Note**: In case of PyTorch training, Debugger collects output tensors in GLOBAL mode by default. In other words, this option does not distinguish output tensors from different phases within an epoch, such as training phase and validation phase.
       
   2. **Script change**: Use the SageMaker Debugger client library, SMDebug, and customize training scripts to save the specific tensors you want at different frequencies and configurations. Refer to the [DeveloperGuide](https://github.com/awslabs/sagemaker-debugger/tree/master/docs) for details on how to use SageMaker Debugger with your choice of framework in your training script.
   
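In this notebook, we choose the second option so that the output tensors from the different training phases are saved separately. The script-change approach is also the one that remains available for PyTorch 1.12 and later, where the zero-code-change option is deprecated.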

### Analysis of tensors

Once tensors are saved, Amazon SageMaker Debugger can be configured to run debugging ***Rules*** on them. At a very broad level, a rule is a Python script that detects certain conditions during training. A data scientist training an algorithm might, for example, want to monitor for gradients getting too large or too small, or to detect overfitting. Amazon SageMaker Debugger comes pre-packaged with a set of built-in rules, and you can write your own rules using the Amazon SageMaker Debugger APIs. You can also analyze raw tensor data outside the Rules construct in a notebook, using Amazon SageMaker Debugger's full set of APIs.
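
As a rough illustration of the kind of logic a rule contains, the sketch below flags the saved steps whose mean absolute gradient falls below a threshold. It is a plain-Python toy, not the built-in `vanishing_gradient` rule and not the SMDebug rule API; the function name `find_vanishing_steps`, the threshold value, and the sample gradient values are all made up for this example.

```python
def find_vanishing_steps(gradients_by_step, threshold=1e-7):
    """Return the steps whose mean absolute gradient is below `threshold`.

    `gradients_by_step` maps a step number to a list of gradient values.
    Both the data layout and the threshold are assumptions for illustration.
    """
    flagged = []
    for step, grads in sorted(gradients_by_step.items()):
        mean_abs = sum(abs(g) for g in grads) / len(grads)
        if mean_abs < threshold:
            flagged.append(step)
    return flagged


# Hypothetical gradient values captured at three saved steps.
saved_gradients = {
    0: [0.2, -0.1, 0.05],
    100: [1e-3, -2e-3, 5e-4],
    200: [1e-9, -3e-9, 2e-9],
}

flagged_steps = find_vanishing_steps(saved_gradients)
print(flagged_steps)
```

In this toy data only step 200 has gradients small enough to be flagged; a real rule runs the same kind of check against the tensors that the hook saved to S3.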

----

## Import SageMaker Python SDK and install required packages

In [None]:
import sagemaker

sagemaker.__version__

This notebook works with the SageMaker Python SDK version **2.39.1 or later**.

In [None]:
import sys


def import_or_install(package):
    """Import `package`, installing it with pip first if it is missing."""
    try:
        __import__(package)
    except ImportError:
        !{sys.executable} -m pip install {package}


required_packages = ["smdebug", "pytest"]

for package in required_packages:
    import_or_install(package)

----

## Modify a PyTorch training script

We will focus on how to modify a training script to save tensors by registering debug hooks and specifying which tensors to save.

The model used for this notebook is trained with the MNIST dataset. The example is based on https://github.com/pytorch/examples/blob/master/mnist/main.py (the version as of October 2020).

### Modifying the training script

Before we define a PyTorch estimator and start training, we will explore parts of the training script in detail. (The entire training script can be found at [./scripts/pytorch_mnist.py](./scripts/pytorch_mnist.py)).

- **Step 1**: Import Amazon SageMaker Debugger client library, SMDebug.

    ```python
    import smdebug.pytorch as smd
    
    ```


- **Step 2**: In the `train()` function, set the SMDebug hook for PyTorch to `TRAIN` mode.

    ```python
    hook.set_mode(smd.modes.TRAIN)
    ```


- **Step 3**: In the `test()` function, set the SMDebug hook for PyTorch to `EVAL` mode.

    ```python
    hook.set_mode(smd.modes.EVAL)
    ```


- **Step 4**: In the `main()` function, create the SMDebug hook and register it with the model and the loss function.

    ```python
    hook = smd.Hook.create_from_json_file()
    hook.register_hook(model)
    hook.register_loss(loss_fn)
    ```


- **Step 5**: In the `main()` function, pass the SMDebug hook to the `train()` and `test()` functions in the epoch loop.

    ```python
    train(args, model, loss_fn, device, train_loader, optimizer, epoch, hook)
    test(model, device, loss_fn, test_loader, hook)
    ```

In [None]:
!pygmentize ./scripts/pytorch_mnist.py

----

## Set up a PyTorch estimator and run a training job

Once these changes are made in the training script, Amazon SageMaker Debugger will start saving tensors during training into a specified output S3 bucket.

Now, we will set up the estimator and start training using the modified training script.

In [None]:
from __future__ import absolute_import

import boto3
import pytest
from sagemaker.pytorch import PyTorch
from sagemaker import get_execution_role
from sagemaker.debugger import (
    Rule,
    ProfilerRule,
    DebuggerHookConfig,
    TensorBoardOutputConfig,
    CollectionConfig,
    rule_configs,
)

Define the configuration of the training job to run. If you bring your own container, `ecr_image` is where you provide the link to it. The `hyperparameters` are fed into the training script. The data directory (where the training dataset is stored) and the smdebug directory (where the tensors will be saved) are mandatory fields.

In [None]:
hyperparameters = {"epochs": "5", "batch-size": "32", "test-batch-size": "100", "lr": "0.001"}

### Configure a Debugger rule object

The `rules` parameter accepts a list of rules that you want to evaluate against the output tensors.

In this example, we use Debugger rules that check for vanishing gradient, overfit, overtraining, and poor weight initialization problems.

In [None]:
rules = [
    Rule.sagemaker(rule_configs.vanishing_gradient()),
    Rule.sagemaker(rule_configs.overfit()),
    Rule.sagemaker(rule_configs.overtraining()),
    Rule.sagemaker(rule_configs.poor_weight_initialization()),
    ProfilerRule.sagemaker(rule_configs.ProfilerReport()),
]

For more information about the rules, see the following documentation.

- [Vanishing gradient](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html#vanishing-gradient)
- [Overfit](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html#overfit)
- [Overtraining](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html#overtraining)
- [Poor weight initialization](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html#poor-weight-initialization)

In addition to the model debugging rules above, SageMaker Debugger runs the ProfilerReport rule by default. This rule detects system bottlenecks and autogenerates a profiling report. For more information, see the following documentation:

- [SageMaker Debugger Profiling Report](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-profiling-report.html)
- [ProfilerReport rule](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html#profiler-report)

### Configure Debugger hook parameters

The following code shows how to adjust save intervals of the output tensors in the different training phases.

In [None]:
hook_config = DebuggerHookConfig(
    hook_parameters={"train.save_interval": "100", "eval.save_interval": "10"}
)
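
To get a feel for what these intervals mean, the following sketch estimates how many steps end up saved in each phase. It assumes, as a simplification, that the hook saves a step whenever its per-mode index is divisible by the save interval, that step counters start at 0 and keep counting across epochs, and that the dataset is the standard MNIST split (60,000 training and 10,000 test images) with the batch sizes from `hyperparameters` above. The helper `saved_steps` is hypothetical and not part of SMDebug; the counts in an actual job may differ.

```python
import math


def saved_steps(total_steps, save_interval):
    """Step indices that would be saved under the divisibility assumption."""
    return list(range(0, total_steps, save_interval))


epochs = 5
train_steps = epochs * math.ceil(60_000 / 32)  # batch-size = 32
eval_steps = epochs * math.ceil(10_000 / 100)  # test-batch-size = 100

train_saved = saved_steps(train_steps, 100)  # train.save_interval = 100
eval_saved = saved_steps(eval_steps, 10)  # eval.save_interval = 10

print(train_steps, len(train_saved))
print(eval_steps, len(eval_saved))
```

Under these assumptions, a smaller interval for the evaluation phase keeps the number of saved evaluation steps in the same order of magnitude as the training steps, even though evaluation runs far fewer steps overall.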

### Construct a PyTorch estimator with the Debugger parameters

In [None]:
estimator = PyTorch(
    entry_point="scripts/pytorch_mnist.py",
    base_job_name="smdebugger-demo-mnist-pytorch",
    role=get_execution_role(),
    instance_count=1,
    instance_type="ml.p2.xlarge",
    volume_size=400,
    max_run=3600,
    hyperparameters=hyperparameters,
    framework_version="1.8",
    py_version="py36",
    ## Debugger parameters
    rules=rules,
    debugger_hook_config=hook_config,
)

### Start the training job

In [None]:
estimator.fit(wait=True)

----

## Check SageMaker Debugger rule summaries

As a result of calling the `fit()` method, Amazon SageMaker Debugger starts rule evaluation jobs that monitor `vanishing_gradient()`, `overfit()`, `overtraining()`, and `poor_weight_initialization()` issues in parallel with the training job.

The `ProfilerReport` rule runs for all SageMaker training jobs by default. You will be able to receive a comprehensive training report regarding system bottlenecks and framework profiling.

### Print the latest training job's rule summary in real time

In [None]:
job_name = estimator.latest_training_job.name
client = estimator.sagemaker_session.sagemaker_client
description = client.describe_training_job(TrainingJobName=job_name)

In [None]:
import time
from IPython import display

%matplotlib inline

# Poll until the job reaches a terminal state; "Failed" is included so that
# the loop also ends when the training job fails.
while description["SecondaryStatus"] not in {"Stopped", "Completed", "Failed"}:
    description = client.describe_training_job(TrainingJobName=job_name)
    primary_status = description["TrainingJobStatus"]
    secondary_status = description["SecondaryStatus"]
    print("====================================================================")
    print("TrainingJobStatus: ", primary_status, " | SecondaryStatus: ", secondary_status)
    print("====================================================================")
    # Fetch the rule summaries once per polling iteration.
    for rule_summary in estimator.latest_training_job.rule_job_summary():
        print(rule_summary["RuleConfigurationName"], ": ", rule_summary["RuleEvaluationStatus"])
        if rule_summary["RuleEvaluationStatus"] == "IssuesFound":
            print(rule_summary["StatusDetails"])
        print("====================================================================")
    print("Current time: ", time.asctime())
    display.clear_output(wait=True)
    time.sleep(100)

### Print URLs to the corresponding processing job logs in CloudWatch

In [None]:
def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
    """Helper function to get the rule job name with correct casing"""
    return "{}-{}-{}".format(
        training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
    )


def _get_cw_url_for_rule_job(rule_job_name, region):
    return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(
        region, region, rule_job_name
    )


def get_rule_jobs_cw_urls(estimator):
    region = boto3.Session().region_name
    training_job = estimator.latest_training_job
    # Describe the job once and reuse the response.
    description = training_job.describe()
    training_job_name = description["TrainingJobName"]
    rule_eval_statuses = description["DebugRuleEvaluationStatuses"]

    result = {}
    for status in rule_eval_statuses:
        if status.get("RuleEvaluationJobArn", None) is not None:
            rule_job_name = _get_rule_job_name(
                training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"]
            )
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(
                rule_job_name, region
            )
    return result


get_rule_jobs_cw_urls(estimator)
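
To make the naming scheme in the helper above concrete, the sketch below applies the same truncation logic to made-up inputs. The training job name, rule name, and ARN are hypothetical placeholders, and the function is re-declared as `rule_job_name` only so that the example is self-contained.

```python
def rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
    """Same logic as the helper above: 26 chars + 26 chars + last 8 chars of the ARN."""
    return "{}-{}-{}".format(
        training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
    )


# Hypothetical values, for illustration only.
example_training_job = "smdebugger-demo-mnist-pytorch-2024-01-01-00-00-00-000"
example_rule = "VanishingGradient"
example_arn = "arn:aws:sagemaker:us-west-2:123456789012:processing-job/example-rule-job-1a2b3c4d"

example_name = rule_job_name(example_training_job, example_rule, example_arn)
print(example_name)
```

The training job name is cut to its first 26 characters, the rule name is short enough to be kept whole, and only the last 8 characters of the ARN are appended. The resulting string is what the CloudWatch URL uses as the log stream prefix.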

----

## SageMaker Debugger reports and analysis

Another aspect of Amazon SageMaker Debugger is analysis. It lets you interactively explore the saved tensors, either in real time or after the job has finished. Here we focus on after-the-fact analysis of the job above. We import the smdebug library, which defines the concept of a Trial that represents a single training run. Note how we fetch the path to the Debugger artifacts for the job.

### Create an SMDebug trial object and retrieve saved output tensors

In [None]:
from smdebug.trials import create_trial
from smdebug.core.modes import ModeKeys

trial = create_trial(estimator.latest_job_debugger_artifacts_path())

### Check which output tensors are saved

We can list all the tensors that were recorded to know what we want to plot.

In [None]:
trial.tensor_names()

We can also retrieve tensors through the default collections that smdebug creates for your training job. Amazon SageMaker Debugger automatically creates default collections such as weights, gradients, biases, and losses, and you can also create custom collections from your tensors. Here we are interested in the losses collection; its tensor names can be listed with `trial.tensor_names(collection="losses")`.

### Check the number of steps saved in the different training phases

In [None]:
len(trial.tensor("NLLLoss_output_0").steps(mode=ModeKeys.TRAIN))

In [None]:
len(trial.tensor("NLLLoss_output_0").steps(mode=ModeKeys.EVAL))

### Set up functions to log and plot the output tensors

In [None]:
def get_data(trial, tname, mode):
    tensor = trial.tensor(tname)
    steps = tensor.steps(mode=mode)
    vals = []
    for s in steps:
        vals.append(tensor.value(s, mode=mode))
    return steps, vals

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import host_subplot


def plot_tensor(trial, tensor_name):
    steps_train, vals_train = get_data(trial, tensor_name, mode=ModeKeys.TRAIN)
    print("loaded TRAIN data")
    steps_eval, vals_eval = get_data(trial, tensor_name, mode=ModeKeys.EVAL)
    print("loaded EVAL data")

    fig = plt.figure(figsize=(10, 7))
    host = host_subplot(111)

    par = host.twiny()

    host.set_xlabel("Steps (TRAIN)")
    par.set_xlabel("Steps (EVAL)")
    host.set_ylabel(tensor_name)

    (p1,) = host.plot(steps_train, vals_train, label=tensor_name)
    print("completed TRAIN plot")
    (p2,) = par.plot(steps_eval, vals_eval, label="val_" + tensor_name)
    print("completed EVAL plot")
    leg = plt.legend()

    host.xaxis.get_label().set_color(p1.get_color())
    leg.texts[0].set_color(p1.get_color())

    par.xaxis.get_label().set_color(p2.get_color())
    leg.texts[1].set_color(p2.get_color())

    plt.ylabel(tensor_name)

    plt.show()

In [None]:
plot_tensor(trial, "NLLLoss_output_0")

### Reflect on the rule summary report

Recall what the rule summary reported:

```
Overfit :  IssuesFound
RuleEvaluationConditionMet: Evaluation of the rule Overfit at step 4000 resulted in the condition being met
```

Based on this rule evaluation and the plot above, we can conclude that the training job has an overfitting issue. While the `NLLLoss_output_0` line is decreasing, the `val_NLLLoss_output_0` line is fluctuating and not decreasing.

To resolve the overfitting problem, consider applying or double-checking the following techniques:
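
One way to put a number on the gap between the two curves is to compare the mean validation loss with the mean training loss. The sketch below is a simplified illustration with made-up loss values; `overfit_ratio` and the 0.1 cutoff are assumptions chosen for the example and do not reproduce the exact logic of the built-in `overfit` rule.

```python
def overfit_ratio(train_losses, val_losses):
    """Relative gap between the mean validation loss and the mean training loss."""
    mean_train = sum(train_losses) / len(train_losses)
    mean_val = sum(val_losses) / len(val_losses)
    return (mean_val - mean_train) / mean_train


# Hypothetical loss values: training keeps improving, validation does not.
train_losses = [0.10, 0.08, 0.06, 0.05]
val_losses = [0.12, 0.13, 0.11, 0.14]

ratio = overfit_ratio(train_losses, val_losses)
ratio_threshold = 0.1  # assumed cutoff for this example
print(round(ratio, 3), ratio > ratio_threshold)
```

With real data, the two lists could be filled from the values returned by `get_data()` for the `TRAIN` and `EVAL` modes.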

- Regularization
- Weight initialization
- Dropout regularization
- Weight constraints

### Download, open, and display the ProfilerReport HTML file

In [None]:
rule_output_path = estimator.output_path + estimator.latest_training_job.job_name + "/rule-output"

In [None]:
! aws s3 ls {rule_output_path} --recursive

In [None]:
! aws s3 cp {rule_output_path} ./ --recursive

In [None]:
import os

# get the autogenerated folder name of profiler report
profiler_report_name = [
    rule["RuleConfigurationName"]
    for rule in estimator.latest_training_job.rule_job_summary()
    if "Profiler" in rule["RuleConfigurationName"]
][0]

In [None]:
import IPython

IPython.display.HTML(filename=profiler_report_name + "/profiler-output/profiler-report.html")

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/sagemaker-debugger|pytorch_model_debugging|pytorch_script_change_smdebug.ipynb)
