{ "cells": [ { "cell_type": "markdown", "id": "afbae70b", "metadata": {}, "source": [ "## Training Notebook\n", "\n", "This notebook illustrates training of a simple model to classify digits using the MNIST dataset. This code is used to train the model included with the templates. This is meant to be a starter model to show you how to set up Serverless applications to do inferences. For deeper understanding of how to train a good model for MNIST, we recommend literature from the [MNIST website](http://yann.lecun.com/exdb/mnist/). The dataset is made available under a [Creative Commons Attribution-Share Alike 3.0](https://creativecommons.org/licenses/by-sa/3.0/) license." ] }, { "cell_type": "code", "execution_count": 1, "id": "56d620c5", "metadata": {}, "outputs": [], "source": [ "# Install required dependencies\n", "\n", "! pip install -q torch==1.8.0 torchvision==0.9.0" ] }, { "cell_type": "code", "execution_count": 2, "id": "8f6b4a1b", "metadata": {}, "outputs": [], "source": [ "# Torchvision provides an easy way to import MNIST dataset into DataLoaders\n", "\n", "import torch\n", "import torchvision\n", "from torchvision.transforms import ToTensor\n", "\n", "# mini-batch size when training and testing\n", "mini_batch_size = 64\n", "\n", "train_loader = torch.utils.data.DataLoader(\n", " torchvision.datasets.MNIST('./mnist_data/', train=True, download=True, transform=ToTensor()),\n", " batch_size=mini_batch_size)\n", "\n", "test_loader = torch.utils.data.DataLoader(\n", " torchvision.datasets.MNIST('./mnist_data/', train=False, download=True, transform=ToTensor()),\n", " batch_size=mini_batch_size)\n" ] }, { "cell_type": "markdown", "id": "27c5ae17", "metadata": {}, "source": [ "## PyTorch Model Training\n", "\n", "For this example, we will train a simple CNN classifier using PyTorch to classify the MNIST digits. We will then freeze the model in the TorchScript format. This is same as the starter model file included with the SAM templates." 
] }, { "cell_type": "code", "execution_count": 6, "id": "c93daad0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n", "Model(\n", " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (convbn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (convbn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layer1): Linear(in_features=800, out_features=100, bias=True)\n", " (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layer2): Linear(in_features=100, out_features=100, bias=True)\n", " (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layer3): Linear(in_features=100, out_features=100, bias=True)\n", " (bn3): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layer4): Linear(in_features=100, out_features=100, bias=True)\n", " (bn4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layer5): Linear(in_features=100, out_features=100, bias=True)\n", " (bn5): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (smax): Linear(in_features=100, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "# Use a GPU if set up on this machine\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(\"Using {} device\".format(device))\n", "\n", "# We'll start with building a model\n", "class Model(nn.Module):\n", " \n", " def __init__(self):\n", " super(Model, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n", " self.convbn1 = nn.BatchNorm2d(32)\n", " \n", " self.conv2 = nn.Conv2d(32, 32, kernel_size=3)\n", " self.convbn2 = nn.BatchNorm2d(32)\n", " \n", " layer_size = 100\n", "\n", " self.layer1 = nn.Linear(800, layer_size)\n", " self.bn1 = nn.BatchNorm1d(layer_size)\n", " \n", " self.layer2 = nn.Linear(layer_size, layer_size)\n", " self.bn2 = nn.BatchNorm1d(layer_size)\n", " \n", " self.layer3 = nn.Linear(layer_size, layer_size)\n", " self.bn3 = nn.BatchNorm1d(layer_size)\n", " \n", " self.layer4 = nn.Linear(layer_size, layer_size)\n", " self.bn4 = nn.BatchNorm1d(layer_size)\n", " \n", " self.layer5 = nn.Linear(layer_size, layer_size)\n", " self.bn5 = nn.BatchNorm1d(layer_size)\n", " \n", " self.smax = nn.Linear(layer_size, 10)\n", "\n", " def forward(self, x):\n", " x = self.convbn1(F.relu(F.max_pool2d(self.conv1(x), 2)))\n", " x = F.dropout2d(x, training=self.training)\n", " \n", " x = self.convbn2(F.relu(F.max_pool2d(self.conv2(x), 2)))\n", " x = F.dropout2d(x, training=self.training)\n", " \n", " x = x.view(-1, 800)\n", " x = F.dropout(self.bn1(F.relu(self.layer1(x))), training=self.training)\n", " x = F.dropout(self.bn2(F.relu(self.layer2(x))), training=self.training)\n", " x = F.dropout(self.bn3(F.relu(self.layer3(x))), training=self.training)\n", " x = F.dropout(self.bn4(F.relu(self.layer4(x))), training=self.training)\n", " x = F.dropout(self.bn5(F.relu(self.layer5(x))), training=self.training)\n", " \n", " return self.smax(x)\n", "\n", "model = Model().to(device)\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 7, "id": "9c6d8295", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 0\n", 
"---------------------------------------------\n", "loss: 2.399322748184204 [0/60000]\n", "loss: 2.5006144046783447 [12800/60000]\n", "loss: 2.528806447982788 [25600/60000]\n", "loss: 2.287709951400757 [38400/60000]\n", "loss: 2.41180419921875 [51200/60000]\n", "Test accuracy: 14.360000000000001%, avg loss: 0.0338871225476265\n", "\n", "Epoch 1\n", "---------------------------------------------\n", "loss: 2.191370964050293 [0/60000]\n", "loss: 0.7450801730155945 [12800/60000]\n", "loss: 0.2356947809457779 [25600/60000]\n", "loss: 0.13707228004932404 [38400/60000]\n", "loss: 0.2939474582672119 [51200/60000]\n", "Test accuracy: 95.89%, avg loss: 0.0021439245976740493\n", "\n", "Epoch 2\n", "---------------------------------------------\n", "loss: 0.14360113441944122 [0/60000]\n", "loss: 0.08505825698375702 [12800/60000]\n", "loss: 0.06298833340406418 [25600/60000]\n", "loss: 0.07103477418422699 [38400/60000]\n", "loss: 0.18312042951583862 [51200/60000]\n", "Test accuracy: 97.7%, avg loss: 0.0011607079559122213\n", "\n", "Epoch 3\n", "---------------------------------------------\n", "loss: 0.08531039953231812 [0/60000]\n", "loss: 0.05882638320326805 [12800/60000]\n", "loss: 0.016713779419660568 [25600/60000]\n", "loss: 0.07784833014011383 [38400/60000]\n", "loss: 0.15998023748397827 [51200/60000]\n", "Test accuracy: 98.02%, avg loss: 0.0009536874826851999\n", "\n", "Epoch 4\n", "---------------------------------------------\n", "loss: 0.044536370784044266 [0/60000]\n", "loss: 0.05791319161653519 [12800/60000]\n", "loss: 0.01448937226086855 [25600/60000]\n", "loss: 0.08377320319414139 [38400/60000]\n", "loss: 0.14489729702472687 [51200/60000]\n", "Test accuracy: 98.22%, avg loss: 0.0009113169919452048\n", "\n", "Epoch 5\n", "---------------------------------------------\n", "loss: 0.027671929448843002 [0/60000]\n", "loss: 0.0939938873052597 [12800/60000]\n", "loss: 0.012829871848225594 [25600/60000]\n", "loss: 0.07140941172838211 [38400/60000]\n", "loss: 0.1233079582452774 [51200/60000]\n", "Test accuracy: 98.45%, avg loss: 0.0007938099530096224\n", "\n", "Epoch 6\n", "---------------------------------------------\n", "loss: 0.012286216951906681 [0/60000]\n", "loss: 0.10050684213638306 [12800/60000]\n", "loss: 0.005056383088231087 [25600/60000]\n", "loss: 0.03812812268733978 [38400/60000]\n", "loss: 0.08250676095485687 [51200/60000]\n", "Test accuracy: 98.58%, avg loss: 0.0007693227126433158\n", "\n", "Epoch 7\n", "---------------------------------------------\n", "loss: 0.0035060422960668802 [0/60000]\n", "loss: 0.037150781601667404 [12800/60000]\n", "loss: 0.012905079871416092 [25600/60000]\n", "loss: 0.00579256284981966 [38400/60000]\n", "loss: 0.07229607552289963 [51200/60000]\n", "Test accuracy: 98.61%, avg loss: 0.0007274648112347222\n", "\n", "Epoch 8\n", "---------------------------------------------\n", "loss: 0.0010487399995326996 [0/60000]\n", "loss: 0.0195145420730114 [12800/60000]\n", "loss: 0.0031410474330186844 [25600/60000]\n", "loss: 0.03495214134454727 [38400/60000]\n", "loss: 0.05204099044203758 [51200/60000]\n", "Test accuracy: 98.79%, avg loss: 0.000661323677877158\n", "\n", "Epoch 9\n", "---------------------------------------------\n", "loss: 0.00041415600571781397 [0/60000]\n", "loss: 0.004038047045469284 [12800/60000]\n", "loss: 0.0016652886988595128 [25600/60000]\n", "loss: 0.004290355369448662 [38400/60000]\n", "loss: 0.046403829008340836 [51200/60000]\n", "Test accuracy: 98.5%, avg loss: 0.0008367217959202435\n", "Done!\n" ] } ], "source": [ "# Define some hand 
tuned hyperparameters\n", "# (we already defined the mini-batch size above)\n", "\n", "epochs = 10\n", "learning_rate = 1e-4\n", "log_step = 200\n", "\n", "# Define our loss function and optimizer\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "\n", "# Single training epoch loop\n", "def train(train_loader, model, loss_fn, optimizer):\n", "    size = len(train_loader.dataset)\n", "\n", "    # Ensure training mode; test() below switches the model to eval mode\n", "    model.train()\n", "\n", "    for batch, (X, y) in enumerate(train_loader):\n", "        X, y = X.to(device), y.to(device)\n", "\n", "        # Forward pass and compute loss\n", "        pred = model(X)\n", "        loss = loss_fn(pred, y)\n", "\n", "        # Backpropagate loss\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if batch % log_step == 0:\n", "            loss, current = loss.item(), batch * len(X)\n", "            print(f'loss: {loss} [{current}/{size}]')\n", "\n", "\n", "def test(test_loader, model):\n", "    size = len(test_loader.dataset)\n", "    model.eval()\n", "\n", "    test_loss, correct = 0, 0\n", "    with torch.no_grad():\n", "        for X, y in test_loader:\n", "            X, y = X.to(device), y.to(device)\n", "            pred = model(X)\n", "\n", "            # Accumulate each batch's mean loss; dividing by the dataset size\n", "            # below yields a (small) per-sample average\n", "            test_loss += loss_fn(pred, y).item()\n", "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", "\n", "    test_loss /= size\n", "    correct /= size\n", "\n", "    print(f'Test accuracy: {100*correct}%, avg loss: {test_loss}')\n", "\n", "# Driver loop to start training\n", "for epoch_no in range(epochs):\n", "    print(f'\\nEpoch {epoch_no}\\n---------------------------------------------')\n", "\n", "    train(train_loader, model, loss_fn, optimizer)\n", "    test(test_loader, model)\n", "\n", "print('Done!')" ] }, { "cell_type": "markdown", "id": "ffb79297", "metadata": {}, "source": [ "We will save the model as a [TorchScript](https://pytorch.org/docs/stable/jit.html) file to export it for inference. Note that PyTorch offers [more ways](https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=load#saving-loading-model-for-inference) of saving models depending on your use case and execution environment." ] }, { "cell_type": "code", "execution_count": 31, "id": "5eab2a38", "metadata": {}, "outputs": [], "source": [ "# Move the model to CPU and convert it to a TorchScript model\n", "scripted_model = torch.jit.script(model.cpu())\n", "\n", "# Sanity-check that the original and scripted models give the same results on random input\n", "model.eval()\n", "scripted_model.eval()\n", "\n", "for _ in range(1000):\n", "    X = torch.randn(1, 1, 28, 28)\n", "\n", "    pt_ans = torch.argmax(model(X)).item()\n", "    ts_ans = torch.argmax(scripted_model(X)).item()\n", "    assert pt_ans == ts_ans\n", "\n", "# Freeze the scripted model to include with the template\n", "scripted_model.save('digit_classifier.pt')" ] },
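{ "cell_type": "markdown", "id": "d41c9f20", "metadata": {}, "source": [ "As a quick check of the exported file, the TorchScript model can be loaded back with [torch.jit.load](https://pytorch.org/docs/stable/generated/torch.jit.load.html) and used for inference. The sketch below is illustrative only; it reuses the `test_loader` defined earlier and is not required for the templates." ] }, { "cell_type": "code", "execution_count": null, "id": "3c7e51aa", "metadata": {}, "outputs": [], "source": [ "# Load the frozen model back from disk (illustrative sketch)\n", "loaded_model = torch.jit.load('digit_classifier.pt')\n", "loaded_model.eval()\n", "\n", "# Classify the first image of a test mini-batch; X has shape (64, 1, 28, 28)\n", "X, y = next(iter(test_loader))\n", "with torch.no_grad():\n", "    pred = loaded_model(X[:1])\n", "\n", "print(f'predicted: {pred.argmax(1).item()}, actual: {y[0].item()}')" ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }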