# Train a PyTorch model with MNIST dataset


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

---


This notebook shows how to use the `@remote` decorator and `RemoteExecutor` from the SageMaker Python SDK
to train PyTorch models in remote jobs and track the training runs with SageMaker Experiments.

## Install the dependencies

In [None]:
# Install the packages listed in requirements.txt into the environment of the running kernel.
%pip install -r ./requirements.txt

In [None]:
import sagemaker
from sagemaker.experiments.run import Run, load_run
from sagemaker.remote_function import remote, RemoteExecutor

# SageMaker session (also passed to the experiment Runs) and the S3 URI inside the
# session's default bucket that serves as the root location for the RemoteExecutor jobs.
sm_session = sagemaker.Session()
s3_root_folder = f"s3://{sm_session.default_bucket()}/remote_function_demo/pytorch_mnist"

## Load the MNIST data

Download the data to the `./data` folder, then load and normalize it.

In [None]:
from torchvision import datasets, transforms


# Replace the class-level download mirrors so the MNIST archives are fetched
# from the regional SageMaker example-files bucket.
datasets.MNIST.mirrors = [
    f"https://sagemaker-example-files-prod-{sm_session.boto_region_name}.s3.amazonaws.com/datasets/image/MNIST/"
]

# One preprocessing pipeline shared by both splits: convert images to tensors,
# then normalize with mean 0.1307 / std 0.3081 (presumably the MNIST pixel
# statistics).
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Build the training split first, then the test split, downloading into ./data if needed.
train_set, test_set = (
    datasets.MNIST("./data", train=is_train, transform=mnist_transform, download=True)
    for is_train in (True, False)
)

## Define the model architecture and training logic

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    """Convolutional classifier producing log-probabilities over 10 classes.

    Two 3x3 convolutions with ReLU, a 2x2 max-pool, then two fully-connected
    layers, with dropout after the pool and after the first linear layer.
    ``fc1`` expects 9216 = 64 * 12 * 12 features, which is what a 1x28x28
    input yields after the conv/pool stack.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: with a fixed seed it determines the initial weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities with shape (batch, 10)."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.dropout1(F.max_pool2d(h, 2))
        h = torch.flatten(h, 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return F.log_softmax(self.fc2(h), dim=1)

In [None]:
def train(model, device, train_loader, optimizer, epoch, log_interval, dry_run):
    """Run a single training epoch over ``train_loader``.

    Each batch is moved to ``device``, scored with NLL loss on the model's
    log-probability output and used for one optimizer step. Progress is
    printed every ``log_interval`` batches (batch 0 is always printed).
    When ``dry_run`` is true, training stops after the first batch.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            message = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                step * len(inputs),
                len(train_loader.dataset),
                100.0 * step / len(train_loader),
                loss.item(),
            )
            print(message)
        if dry_run:
            break

In [None]:
def check_performance(run, model, device, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` and report the results.

    Computes the mean per-sample NLL loss and the accuracy in percent, logs
    them to the experiment ``run`` as ``test:loss`` / ``test:accuracy`` at
    step ``epoch``, and prints a one-line summary.
    """
    model.eval()
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            log_probs = model(inputs)
            # Sum within each batch so the division below yields a per-sample mean.
            total_loss += F.nll_loss(log_probs, labels, reduction="sum").item()
            predicted = log_probs.argmax(dim=1, keepdim=True)
            n_correct += (predicted == labels.view_as(predicted)).sum().item()

    n_samples = len(test_loader.dataset)
    mean_loss = total_loss / n_samples
    accuracy = 100.0 * n_correct / n_samples

    run.log_metric(name="test:loss", value=mean_loss, step=epoch)
    run.log_metric(name="test:accuracy", value=accuracy, step=epoch)

    summary = "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        mean_loss, n_correct, n_samples, accuracy
    )
    print(summary)

In [None]:
# Train the model for the specified number of epochs and test the performance after each epoch.
def perform_train(
    train_data,
    test_data,
    *,
    batch_size: int = 64,
    test_batch_size: int = 1000,
    epochs: int = 3,
    lr: float = 1.0,
    gamma: float = 0.7,
    no_cuda: bool = True,
    no_mps: bool = True,
    dry_run: bool = False,
    seed: int = 1,
    log_interval: int = 10,
):
    """PyTorch MNIST Example

    Trains ``Net`` and tracks the training in a SageMaker experiment run. The
    run is obtained from the current context with ``load_run()``, so this is
    expected to be called inside an active ``Run``. Hyperparameters are logged
    once, test metrics after every epoch, and a confusion matrix of the first
    test batch at the end.

    :param train_data: the training data set
    :param test_data: the test data set
    :param batch_size: input batch size for training (default: 64)
    :param test_batch_size: input batch size for testing (default: 1000)
    :param epochs: number of epochs to train (default: 3)
    :param lr: learning rate (default: 1.0)
    :param gamma: learning rate step gamma, applied after every epoch (default: 0.7)
    :param no_cuda: disables CUDA training (default: True)
    :param no_mps: disables macOS GPU training (default: True)
    :param dry_run: quickly check a single pass (one training batch per epoch)
    :param seed: random seed (default: 1)
    :param log_interval: how many batches to wait before logging training status (default: 10)
    :return: the trained model
    """

    use_cuda = not no_cuda and torch.cuda.is_available()
    use_mps = not no_mps and torch.backends.mps.is_available()

    torch.manual_seed(seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Shuffle the training data on every device, not only on CUDA; the order
    # draws from the torch RNG seeded above. Evaluation order does not affect
    # the metrics, so the test loader is never shuffled.
    train_kwargs = {"batch_size": batch_size, "shuffle": True}
    test_kwargs = {"batch_size": test_batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    # Multiply the learning rate by ``gamma`` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    # load the experiment run from the context
    with load_run() as run:
        run.log_parameters({"epochs": epochs, "lr": lr, "gamma": gamma})

        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch, log_interval, dry_run)
            check_performance(run, model, device, test_loader, epoch)
            scheduler.step()

        # Log a confusion matrix for the first test batch. Switch to eval mode
        # explicitly so dropout is off even when epochs == 0, flatten the
        # predictions to (N,) to match the targets, and hand over CPU tensors
        # because predictions may live on a GPU/MPS device.
        model.eval()
        with torch.no_grad():
            data, target = next(iter(test_loader))
            pred = model(data.to(device)).argmax(dim=1)
        run.log_confusion_matrix(target.cpu(), pred.cpu(), "confusion-matrix-test-data")

    return model

## Execute and test the training function locally

Set `dry_run` to `True` and execute the function locally to verify that the code and its dependencies work correctly.

In [None]:
# Smoke-test perform_train in the local kernel: a single epoch, with dry_run=True
# so training stops after one batch. The surrounding Run supplies the experiment
# context that perform_train loads.
local_run_kwargs = {
    "experiment_name": "local-tests",
    "run_name": "local-tests",
    "sagemaker_session": sm_session,
}
with Run(**local_run_kwargs) as run:
    trained_model = perform_train(train_set, test_set, epochs=1, dry_run=True)

## Set up the configuration file path
Set the directory that contains the `config.yaml` file so that the `@remote` decorator can use its settings.

In [None]:
import os

# Tell the SageMaker SDK to read config.yaml from the current working directory,
# so the remote-function settings defined there are applied.
os.environ.update({"SAGEMAKER_USER_CONFIG_OVERRIDE": os.getcwd()})

## Execute the function remotely in parallel

Here we run a series of jobs to explore how the hyperparameter `lr` impacts performance.
The argument `max_parallel_jobs` controls the maximum number of jobs that can run in parallel.
We also take advantage of warm pools to reduce the job startup time.

In [None]:
from sagemaker.remote_function import RemoteExecutor

# Learning-rate sweep: one remote training job per run. Each job is submitted
# while its experiment Run is active, so load_run() inside perform_train can
# find that run. At most two jobs run concurrently, and instances are kept
# warm for 60 seconds between jobs to reduce startup time.
mnist_data_uri = f"https://sagemaker-example-files-prod-{sm_session.boto_region_name}.s3.amazonaws.com/datasets/image/MNIST/"
# All learning rates are floats, matching perform_train's ``lr: float`` annotation.
learning_rates = {"run-0": 0.3, "run-1": 1.0, "run-2": 3.0}

with RemoteExecutor(
    max_parallel_jobs=2, keep_alive_period_in_seconds=60, s3_root_uri=s3_root_folder
) as executor:
    futures = {}
    for run_name, lr in learning_rates.items():
        with Run(
            experiment_name="pytorch-mnist", run_name=run_name, sagemaker_session=sm_session
        ) as run:
            # Record where the raw data comes from as an input artifact of the run.
            run.log_artifact(name="raw_data", value=mnist_data_uri, is_output=False)
            futures[run_name] = executor.submit(perform_train, train_set, test_set, lr=lr)

### Go to SageMaker Studio to view the experiment runs

![charts of run-0](./images/experiment_snapshot.png)

In [None]:
# Wait for each remote job and collect the returned models in submission order.
trained_models = list(map(lambda fut: fut.result(), futures.values()))

## Execute the function remotely with function decorator

Note that duplicating the `perform_train` implementation here is for demonstration purposes only. `@remote` can be applied to the original `perform_train` once the local test runs successfully.

In [None]:
from sagemaker.remote_function import remote


@remote
def perform_train(
    train_data,
    test_data,
    *,
    batch_size: int = 64,
    test_batch_size: int = 1000,
    epochs: int = 3,
    lr: float = 1.0,
    gamma: float = 0.7,
    no_cuda: bool = True,
    no_mps: bool = True,
    dry_run: bool = False,
    seed: int = 1,
    log_interval: int = 10,
):
    """PyTorch MNIST Example

    Trains ``Net`` and tracks the training in a SageMaker experiment run. The
    run is obtained from the current context with ``load_run()``, so this is
    expected to be called inside an active ``Run``. Hyperparameters are logged
    once, test metrics after every epoch, and a confusion matrix of the first
    test batch at the end.

    :param train_data: the training data set
    :param test_data: the test data set
    :param batch_size: input batch size for training (default: 64)
    :param test_batch_size: input batch size for testing (default: 1000)
    :param epochs: number of epochs to train (default: 3)
    :param lr: learning rate (default: 1.0)
    :param gamma: learning rate step gamma, applied after every epoch (default: 0.7)
    :param no_cuda: disables CUDA training (default: True)
    :param no_mps: disables macOS GPU training (default: True)
    :param dry_run: quickly check a single pass (one training batch per epoch)
    :param seed: random seed (default: 1)
    :param log_interval: how many batches to wait before logging training status (default: 10)
    :return: the trained model
    """

    use_cuda = not no_cuda and torch.cuda.is_available()
    use_mps = not no_mps and torch.backends.mps.is_available()

    torch.manual_seed(seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Shuffle the training data on every device, not only on CUDA; the order
    # draws from the torch RNG seeded above. Evaluation order does not affect
    # the metrics, so the test loader is never shuffled.
    train_kwargs = {"batch_size": batch_size, "shuffle": True}
    test_kwargs = {"batch_size": test_batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    # Multiply the learning rate by ``gamma`` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    # load the experiment run from the context
    with load_run() as run:
        run.log_parameters({"epochs": epochs, "lr": lr, "gamma": gamma})

        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch, log_interval, dry_run)
            check_performance(run, model, device, test_loader, epoch)
            scheduler.step()

        # Log a confusion matrix for the first test batch. Switch to eval mode
        # explicitly so dropout is off even when epochs == 0, flatten the
        # predictions to (N,) to match the targets, and hand over CPU tensors
        # because predictions may live on a GPU/MPS device.
        model.eval()
        with torch.no_grad():
            data, target = next(iter(test_loader))
            pred = model(data.to(device)).argmax(dim=1)
        run.log_confusion_matrix(target.cpu(), pred.cpu(), "confusion-matrix-test-data")

    return model

In [None]:
# Calling the @remote-decorated perform_train starts a SageMaker job; doing so
# inside this Run lets the job's load_run() pick up run "run-1" of the
# "pytorch-mnist-decorator" experiment.
decorator_run_kwargs = {
    "experiment_name": "pytorch-mnist-decorator",
    "run_name": "run-1",
    "sagemaker_session": sm_session,
}
with Run(**decorator_run_kwargs) as run:
    trained_model = perform_train(train_set, test_set)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/sagemaker-remote-function|pytorch_mnist_sample_notebook|pytorch_mnist.ipynb)
