{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/aws/studio-lab-examples/blob/main/computer-vision/kmnist/cv-kminst.ipynb)\n", "\n", "# Train an image classification model with PyTorch \n", "\n", "## Background\n", "\n", "Image classification (or Image recognition) is a subdomain of computer vision in which an algorithm looks at an image and assigns it a tag from a collection of predefined tags or categories that it has been trained on.\n", "\n", "Vision is responsible for 80-85 percent of our perception of the world, and we, as human beings, trivially perform classification daily on whatever data we come across.\n", "\n", "Therefore, emulating a classification task with the help of neural networks is one of the first uses of computer vision that researchers thought about.\n", "\n", "\n", "## Pytorch\n", "\n", "[PyTorch](https://pytorch.org/) is an open-source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab (FAIR). 
It is free and open-source software released under the Modified BSD license.\n", "\n", "## KMNIST dataset\n", "\n", "In this notebook we will use the KMNIST (Kuzushiji-MNIST) dataset, which ships with torchvision:\n", "https://pytorch.org/vision/stable/datasets.html#torchvision-datasets\n", "\n", "## Steps in this notebook\n", "\n", "The main steps are:\n", "- install the required packages\n", "- load the dataset and create the train and test data loaders\n", "- display one image from the KMNIST dataset with matplotlib\n", "- define the neural network (NN) model (further reading: mlu-cv course)\n", "- train the model with a given configuration\n", "- test the model\n", "- plot the train and test loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![KMNIST dataset picture](https://raw.githubusercontent.com/rois-codh/kmnist/master/images/kmnist_examples.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Python packages used later in this notebook:\n", "\n", "- **torch**: installs [PyTorch](https://pytorch.org/), a Python package that provides two high-level features: (1) tensor computation (like NumPy) with strong GPU acceleration and (2) deep neural networks built on a tape-based autograd system\n", "- **torchvision**: datasets, model architectures, and image transforms for computer vision\n", "- **torchsummary**: prints a layer-by-layer summary of a PyTorch model\n", "- **matplotlib**: plotting library used to display images and loss curves\n", "\n", "Before running this notebook, install the required Python packages: open a terminal window (click on File/New/Terminal), go to the directory that contains the env_cv.yml file, and run the following two commands:\n", "- conda env create -f env_cv.yml\n", "- conda activate cv\n", "\n", "After that, switch the notebook kernel (click on Kernel/Change Kernel...) 
to **.conda-cv:Python**." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uncomment these lines if you do not want to use the env_cv.yml file\n", "#%pip install torch\n", "#%pip install torchvision\n", "#%pip install torchsummary\n", "#%pip install matplotlib\n", "#%pip install ipywidgets" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Use the first GPU if one is available, otherwise fall back to the CPU\n", "if torch.cuda.is_available():\n", "    device = \"cuda:0\"\n", "else:\n", "    device = \"cpu\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Directory where the model and optimizer checkpoints are saved\n", "results_dir = 'results'\n", "os.makedirs(results_dir, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "## Create the data loaders. In PyTorch, a DataLoader feeds batches of data to the training loop.\n", "## The datasets come from torchvision and are downloaded automatically,\n", "## so there is no need to fetch them separately.\n", "\n", "# Batch size: the number of samples fed to the model in one iteration\n", "batch_size_train = 64  # small batch size for training\n", "batch_size_test = 1024  # larger batch size for evaluation\n", "\n", "# Define how each image is transformed: convert to a tensor, then normalize.\n", "# Note: (0.1307, 0.3081) are the mean/std commonly used for MNIST, reused here for KMNIST.\n", "image_transform = torchvision.transforms.Compose([\n", "    torchvision.transforms.ToTensor(),\n", "    torchvision.transforms.Normalize(\n", "        (0.1307,), (0.3081,))\n", "])\n", "# Image datasets\n", "train_dataset = torchvision.datasets.KMNIST('dataset/',\n", "                                            train=True,\n", "                                            download=True,\n", "                                            transform=image_transform)\n", "test_dataset = torchvision.datasets.KMNIST('dataset/',\n", "                                           train=False,\n", "                                           download=True,\n", "                                           transform=image_transform)\n", "# Data loaders\n", "train_loader = torch.utils.data.DataLoader(train_dataset,\n", "
batch_size=batch_size_train, \n", " shuffle=True)\n", "test_loader = torch.utils.data.DataLoader(test_dataset,\n", " batch_size=batch_size_test, \n", " shuffle=True)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label: tensor(4)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQj0lEQVR4nO3df4xU5b3H8c+XXyKgCciVoGxsRQhgzd0qIdeU3KiEsiUaaVADxOqNpFsTiTbWRPyRiBIF9ZYGzJW4XkmpUkxDJWI0Wi5qjCH+QAFdBQXN+mOzShuigggLu9/7xx7MKnues8w584N93q9ks7PnO2fmy7CfPTPzzHkec3cB6Pv6VbsBAJVB2IFIEHYgEoQdiARhByIxoJJ3ZmZRvvV/1llnBetmFqxnjZi0tbWl1trb24P7ou9x9x5/oSzP0JuZNUhaLqm/pP9196UZ1++TYR8wIPw385FHHgnWBw4cGKwfOnQoWF+8eHFq7dNPPw3ui74nLewlP403s/6S/kfSryRNkjTXzCaVensAyivPa/Ypkna7+8fu3i7pSUmXF9MWgKLlCfuZkj7r9vPnybYfMLNGM9tiZlty3BeAnMr+Bp27N0lqkvrua3bgRJDnyN4qqa7bz2OSbQBqUJ6wvylpnJn91MwGSZojaUMxbQEoWslP4939iJktkPSCuobeVrn7e4V1dgLp7OwM1ltbw0947rjjjmA9axz+2WefTa0x9Iajcr1md/fnJD1XUC8AyoiPywKRIOxAJAg7EAnCDkSCsAORIOxAJCp6PntflTXOfs899wTrw4cPD9YXLFgQrO/fvz+11q9f+O95Vu/oOziyA5Eg7EAkCDsQCcIORIKwA5Eg7EAkcs0ue9x3xkw1PZo0KTxP59atW4P17du3p9ZOOeWU4L7Lli0L1l944YVgnVNoa0/hs8sCOLEQdiAShB2IBGEHIkHYgUgQdiAShB2IRJ8ZZx8zZkyw/tVXXwXrodNEy+3iiy8O1l988cUKdXKsO++8M1i/9957K9TJsU466aRgfeTIkam1rOm9T2SMswORI+xAJAg7EAnCDkSCsAORIOxAJAg7EIk+M5X0BRdcEKxPnDgxWF+6dGmR7fxA1nTOQ4YMKdt9Z+no6AjW16xZU6FOjlVfXx+sr1y5MlgfP358am3JkiXBfVesWBGst7e3B+u1KFfYzaxF0j5JHZKOuPvkIpoCULwijuwXu/u/CrgdAGXEa3YgEnnD7pL+YWZvmVljT1cws0Yz22JmW3LeF4Ac8j6Nn+rurWZ2uqSNZrbT3V/pfgV3b5LUJDHhJFBNuY7s7t6afN8jab2kKUU0BaB4JYfdzIaa2SlHL0v6paTmohoDUKw8T+NHSVpvZkdv56/u/nwhXZUga271hx9+OFjfvHlzsN6/f//UWkNDQ3DfsWPHBuunnnpqsJ7l4MGDqbWsc77XrVsXrLe0tJTSUq/uv7Gxx7d5vnfdddcF621tbcH64cOHU2uLFy8O7nvGGWcE67fddluwfujQoWC9GkoOu7t/LOnfC+wFQBkx9AZEgrADkSDsQCQIOxAJwg5Eos+c4pq1dPCrr74arD//fOmjhieffHLJ+xZh8ODBqbXQsJwkPfTQQ7nue+DAgcH6woULU2tXXHFFcN+5c+
cG683N4Y91hJarnjp1anDfVatWBetZ/+6bbropWO/s7AzWy4EjOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkegzSzZnmTBhQrD+1FNPBetZU1FXU2hM+NFHHw3u+/rrrwfrySnMqW6++eZgfcqU9PlMrr/++uC+e/fuDdbzyJre+7777gvWs/7dM2bMCNZfeumlYD0PlmwGIkfYgUgQdiAShB2IBGEHIkHYgUgQdiASfeZ89iw7d+4M1mfNmhWsP/7446m1c845J7jvoEGDgvVhw4YF61mfhXjiiSdSa6+99lpw3yzz588P1uvq6oL1efPmpdaOHDlSUk+9FfqMwCWXXBLcN2sa66zz2c8999xgvZzj7Gk4sgORIOxAJAg7EAnCDkSCsAORIOxAJAg7EIloxtmzfPjhh8F6aFw2a8w1a/nfrHPKhw4dGqwfOHAgWA+57LLLgvVJkyYF67feemuwnmcsfcCA8K/n9OnTg/XQvPRXX311cN+sz0Zkzfve0dERrFdD5pHdzFaZ2R4za+62bYSZbTSzXcn34eVtE0BevXka/2dJDT/atlDSJncfJ2lT8jOAGpYZdnd/RdKP5we6XNLq5PJqSbOKbQtA0Up9zT7K3duSy19IGpV2RTNrlNRY4v0AKEjuN+jc3UMTSbp7k6QmqboTTgKxK3Xo7UszGy1Jyfc9xbUEoBxKDfsGSdcml6+V9HQx7QAol8yn8Wa2VtJFkkaa2eeS7pK0VNLfzGy+pE8kXVXOJmvBt99+W/K+33zzTbCetYZ61jj7kCFDUmtLly4N7hua112SGhp+PBDzQ+3t7cF6aH72rLnVly9fHqyPHTs2WA+t37558+bgvlmfD8gaZ9+1a1ewXg2ZYXf3uSmlaQX3AqCM+LgsEAnCDkSCsAORIOxAJAg7EAlOca0BWafIZi2bfNddd6XW6uvrg/vOnj07WM8aWstyww03pNYefPDB4L4vv/xysJ413XNoGu1yT2NdiziyA5Eg7EAkCDsQCcIORIKwA5Eg7EAkCDsQCctaDrjQO4t0ppqsJZlbWlqC9dNOO63k+w4t5yxJ11xzTbCe9fsxZsyYYH3r1q2pte3btwf3vfTSS4P1rFODY+XuPX4wgyM7EAnCDkSCsAORIOxAJAg7EAnCDkSCsAOR4Hz2CshaFnnEiBG5bv/rr79OrS1ZsiS4b9Y4ev/+/YP1lStXBuuhf9stt9wS3Jdx9GJxZAciQdiBSBB2IBKEHYgEYQciQdiBSBB2IBKMs1fAtGnhBW+zxrqz5o1fs2ZNam3Hjh3BfbNceeWVwfrMmTOD9dbW1tTa7t27S+oJpck8spvZKjPbY2bN3bYtMrNWM9uWfIX/xwFUXW+exv9ZUkMP2//k7vXJ13PFtgWgaJlhd/dXJO2tQC8AyijPG3QLzOyd5Gn+8LQrmVmjmW0xsy057gtATqWGfaWksZLqJbVJ+mPaFd29yd0nu/vkEu8LQAFKCru7f+nuHe7eKelRSVOKbQtA0UoKu5mN7vbjryU1p10XQG3IHGc3s7WSLpI00sw+l3SXpIvMrF6SS2qR9LvytVj7Bg0aFKzPmDEj1+13dHQE6ytWrEit5V0XYN68ecF6v37h40Vzc/pxYP/+/SX1hNJkht3d5/aw+bEy9AKgjPi4LBAJwg5EgrADkSDsQCQIOxAJTnEtwIQJE4L1rGWNszzzzDPBep5TRbN6mz59esm3LUlvvPFGrv1RHI7sQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgnH2AowfPz5YzzrN9LvvvgvWFy1aFKxnnQIbMnHixGB98ODBwXp7e3uw/uSTTx53TygPjuxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSCcfYCbNiwIVjPmkr6s88+C9Z37tx53D311oUXXphr/3379gXrH330Ua7bR3E4sgORIOxAJAg7EAnCDkSCsAORIOxAJAg7EAnG2QuQdU73xo0bK9TJ8auvr8+1/44dO4L1w4
cP57p9FCfzyG5mdWb2kpm9b2bvmdlNyfYRZrbRzHYl34eXv10AperN0/gjkv7g7pMk/YekG8xskqSFkja5+zhJm5KfAdSozLC7e5u7v51c3idph6QzJV0uaXVytdWSZpWpRwAFOK7X7Gb2E0k/l/S6pFHu3paUvpA0KmWfRkmNOXoEUIBevxtvZsMk/V3S7939m+4175pRscdZFd29yd0nu/vkXJ0CyKVXYTezgeoK+hp3fyrZ/KWZjU7qoyXtKU+LAIqQ+TTezEzSY5J2uPuybqUNkq6VtDT5/nRZOkQuw4eHB0nGjRuX6/ZbWlpy7Y/K6c1r9l9I+o2kd81sW7LtdnWF/G9mNl/SJ5KuKkuHAAqRGXZ3f1WSpZSnFdsOgHLh47JAJAg7EAnCDkSCsAORIOxAJCxrOeFC78yscnfWh3R91CHdnDlzUmv3339/cN+6urqSejpqz57wZ6nOO++8kvdFady9x18YjuxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSCqaRrQNY4+o033hisL1u2LLXWr195/56ffvrpwfq0aeknRq5du7bodhDAkR2IBGEHIkHYgUgQdiAShB2IBGEHIkHYgUgwzl4DZs+eHaw/8MADwXqesfR169YF62effXawfv755wfrDQ0NqTXG2SuLIzsQCcIORIKwA5Eg7EAkCDsQCcIORIKwA5HozfrsdZL+ImmUJJfU5O7LzWyRpN9K+mdy1dvd/blyNdqXdXR0BOsHDhwI1g8ePJhayxqjX758ebB+9913B+tZ4+xTp05NrWV9PqCzszNYx/HpzYdqjkj6g7u/bWanSHrLzDYmtT+5+3+Xrz0ARenN+uxtktqSy/vMbIekM8vdGIBiHddrdjP7iaSfS3o92bTAzN4xs1VmNjxln0Yz22JmW/K1CiCPXofdzIZJ+ruk37v7N5JWShorqV5dR/4/9rSfuze5+2R3n5y/XQCl6lXYzWyguoK+xt2fkiR3/9LdO9y9U9KjkqaUr00AeWWG3bqmPn1M0g53X9Zt++huV/u1pObi2wNQlN68G/8LSb+R9K6ZbUu23S5prpnVq2s4rkXS78rQXxTWr18frL///vvB+oAB6f+NWftmLdnd3Bz+G541LDhs2LDUWtYU2ihWb96Nf1VST/8rjKkDJxA+QQdEgrADkSDsQCQIOxAJwg5EgrADkWAq6RPABx98ULX7bmtrC9aHDBkSrD/99NOpNU5hrSyO7EAkCDsQCcIORIKwA5Eg7EAkCDsQCcIORMKyzmcu9M7M/inpk26bRkr6V8UaOD612lut9iXRW6mK7O0sd/+3ngoVDfsxd262pVbnpqvV3mq1L4neSlWp3ngaD0SCsAORqHbYm6p8/yG12lut9iXRW6kq0ltVX7MDqJxqH9kBVAhhByJRlbCbWYOZfWBmu81sYTV6SGNmLWb2rpltq/b6dMkaenvMrLnbthFmttHMdiXfe1xjr0q9LTKz1uSx22ZmM6vUW52ZvWRm75vZe2Z2U7K9qo9doK+KPG4Vf81uZv0lfShpuqTPJb0paa67h1czqBAza5E02d2r/gEMM/tPSfsl/cXdf5Zse0DSXndfmvyhHO7ut9ZIb4sk7a/2Mt7JakWjuy8zLmmWpP9SFR+7QF9XqQKPWzWO7FMk7Xb3j929XdKTki6vQh81z91fkbT3R5svl7Q6ubxaXb8sFZfSW01w9zZ3fzu5vE/S0WXGq/rYBfqqiGqE/UxJn3X7+XPV1nrvLukfZvaWmTVWu5kejHL3o3NFfSFpVDWb6UHmMt6V9KNlxmvmsStl+fO8eIPuWFPd/XxJv5J0Q/J0tSZ512uwWho77dUy3pXSwzLj36vmY1fq8ud5VSPsrZLquv08JtlWE9y9Nfm+R9J61d5S1F8eXUE3+b6nyv18r5aW8e5pmXHVwGNXzeXPqxH2NyWNM7OfmtkgSXMkbahCH8cws6HJGycys6GSfqnaW4p6g6Rrk8vXSkqfvrXCam
UZ77RlxlXlx67qy5+7e8W/JM1U1zvyH0m6oxo9pPR1tqTtydd71e5N0lp1Pa07rK73NuZLOk3SJkm7JP2fpBE11Nvjkt6V9I66gjW6Sr1NVddT9HckbUu+Zlb7sQv0VZHHjY/LApHgDTogEoQdiARhByJB2IFIEHYgEoQdiARhByLx/6gFDuZXeyLuAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# import library\n", "\n", "import matplotlib.pyplot as plt\n", "# We can check the dataloader\n", "_, (example_datas, labels) = next(enumerate(test_loader))\n", "sample = example_datas[0][0]\n", "# show the data\n", "plt.imshow(sample, cmap='gray', interpolation='none')\n", "print(\"Label: \"+ str(labels[0]))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "## Now we can start to build our CNN model\n", "## We first import the pytorch nn module and optimizer\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "## Then define the model class\n", "class CNN(nn.Module):\n", " def __init__(self):\n", " super(CNN, self).__init__()\n", " #input channel 1, output channel 10\n", " self.conv1 = nn.Conv2d(1, 10, kernel_size=5, stride=1)\n", " #input channel 10, output channel 20\n", " self.conv2 = nn.Conv2d(10, 20, kernel_size=5, stride=1)\n", " #dropout layer\n", " self.conv2_drop = nn.Dropout2d()\n", " #fully connected layer\n", " self.fc1 = nn.Linear(320, 50)\n", " self.fc2 = nn.Linear(50, 10)\n", " def forward(self, x):\n", " x = self.conv1(x)\n", " x = F.max_pool2d(x, 2)\n", " x = F.relu(x)\n", " x = self.conv2(x)\n", " x = self.conv2_drop(x)\n", " x = F.max_pool2d(x, 2)\n", " x = F.relu(x)\n", " x = x.view(-1, 320)\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " x = F.dropout(x)\n", " x = self.fc2(x)\n", " return F.log_softmax(x)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "## create model and optimizer\n", "learning_rate = 0.01\n", "momentum = 0.5\n", "model = CNN().to(device) #using cpu here\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate,\n", " momentum=momentum)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 10, 24, 24] 260\n", " Conv2d-2 [-1, 20, 8, 8] 5,020\n", " Dropout2d-3 [-1, 20, 8, 8] 0\n", " Linear-4 [-1, 50] 16,050\n", " Linear-5 [-1, 10] 510\n", "================================================================\n", "Total params: 21,840\n", "Trainable params: 21,840\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.06\n", "Params size (MB): 0.08\n", "Estimated Total Size (MB): 0.15\n", "----------------------------------------------------------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_2236/4020281764.py:32: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", " return F.log_softmax(x)\n" ] } ], "source": [ "\n", "from torchsummary import summary\n", "summary(model, (1, 28, 28))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "n_epochs = 3\n", "log_interval = 100\n", "train_losses = []\n", "train_counter = []\n", "test_losses = []\n", "test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def train(epoch):\n", " model.train()\n", " for batch_idx, (data, target) in enumerate(train_loader):\n", " optimizer.zero_grad()\n", " if torch.cuda.is_available(): \n", " output = model(data.cuda())\n", " loss = F.nll_loss(output, target.cuda())\n", " else:\n", " output = model(data)\n", " loss = F.nll_loss(output, target)\n", " loss.backward()\n", " optimizer.step()\n", " if batch_idx % log_interval == 0:\n", " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", " epoch, 
batch_idx * len(data), len(train_loader.dataset),\n", "                100. * batch_idx / len(train_loader), loss.item()))\n", "            train_losses.append(loss.item())\n", "            # Number of training examples seen so far\n", "            train_counter.append(\n", "                (batch_idx*batch_size_train) + ((epoch-1)*len(train_loader.dataset)))\n", "            # Checkpoint the model and optimizer state\n", "            torch.save(model.state_dict(), './results/model.pth')\n", "            torch.save(optimizer.state_dict(), './results/optimizer.pth')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def test():\n", "    model.eval()\n", "    test_loss = 0\n", "    correct = 0\n", "    with torch.no_grad():\n", "        for data, target in test_loader:\n", "            # Move the batch to the same device as the model\n", "            data, target = data.to(device), target.to(device)\n", "            output = model(data)\n", "            # Sum the per-sample losses; the average over the dataset is taken below\n", "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n", "            # Predicted class = index of the largest log-probability\n", "            pred = output.argmax(dim=1, keepdim=True)\n", "            correct += pred.eq(target.view_as(pred)).sum().item()\n", "    test_loss /= len(test_loader.dataset)\n", "    test_losses.append(test_loss)\n", "    print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", "        test_loss, correct, len(test_loader.dataset),\n", "        100. * correct / len(test_loader.dataset)))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_2236/4020281764.py:32: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", "  return F.log_softmax(x)\n", "/home/studio-lab-user/.conda/envs/cv/lib/python3.9/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", "  warnings.warn(warning.format(ret))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Test set: Avg. 
loss: 2.3128, Accuracy: 1116/10000 (11%)\n", "\n", "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.340610\n", "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 2.001754\n", "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 1.643779\n", "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 1.515469\n", "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 1.375710\n", "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 1.045141\n", "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 1.002589\n", "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 1.386037\n", "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 1.052345\n", "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 1.023880\n", "\n", "Test set: Avg. loss: 1.2883, Accuracy: 5843/10000 (58%)\n", "\n", "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.777202\n", "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.865407\n", "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.909203\n", "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.996623\n", "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.718363\n", "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.992785\n", "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.966162\n", "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.730128\n", "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.690188\n", "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.975658\n", "\n", "Test set: Avg. loss: 1.0804, Accuracy: 6442/10000 (64%)\n", "\n", "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.664746\n", "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.724125\n", "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.733852\n", "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.840492\n", "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.726298\n", "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.741529\n", "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.630413\n", "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.801318\n", "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.753129\n", "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.709658\n", "\n", "Test set: Avg. 
loss: 0.9515, Accuracy: 6971/10000 (70%)\n", "\n" ] } ], "source": [ "test()\n", "for epoch in range(1, n_epochs + 1):\n", " train(epoch)\n", " test()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'negative log likelihood loss')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAr3ElEQVR4nO3deXxU1fnH8c/DIohgqRIrghjoD9eIIFGUuuCOikWtVSwuVP3x06qI+4LWpbXVLtatitSFVqOiFZeKa61UrAoNiCzigogQREWqiFCU5fn9cU5gjJnJTTKTSTLf9+s15t5zt2fuyDxzzz33HHN3RESkcLXIdwAiIpJfSgQiIgVOiUBEpMApEYiIFDglAhGRAtcq3wHUVqdOnby4uDjfYYiINClTp0791N2LqlvW5BJBcXEx5eXl+Q5DRKRJMbMP0i1T1ZCISIFTIhARKXBKBCIiBa7J3SMQkeZl9erVVFRUsGrVqnyH0iy0bduWrl270rp168TbKBGISF5VVFTQoUMHiouLMbN8h9OkuTtLly6loqKC7t27J96uMKqGysqguBhatAh/y8ryHZGIRKtWrWLzzTdXEsgCM2PzzTev9dVV878iKCuD4cNh5cow/8EHYR5g6ND8xSUi6ykJZE9dzmXzvyIYNWpDEqi0cmUoFxGRAkgECxbUrlxECsrSpUvp3bs3vXv3Zsstt6RLly7r57/++uuM25aXlzNixIhaHa+4uJhPP/20PiFnXfOvGurWLVQHVVcuIgVv8803Z/r06QBcddVVtG/fngsuuGD98jVr1tCqVfVflaWlpZSWljZEmDnV/K8Irr0W2rX7Zlm7dqFcRKQaw4YN4/TTT6dfv35cdNFFTJkyhT333JM+ffrQv39/3n77bQAmTpzIoEGDgJBETjnlFAYMGECPHj24+eabEx9v/vz57L///vTq1YsDDjiABbHG4uGHH6akpIRddtmFffbZB4DZs2ez++6707t3b3r16sW7775b7/fb/K8IKm8IjxoVqoO6dQtJQDeKRRqdkSMh/jjPmt694cYba79dRUUFr7zyCi1btuSLL75g0qRJtGrVir///e9cdtllPPLII9/a5q233uLFF19k+fLlbLfddpxxxhmJ2vOfffbZnHzyyZx88sncfffdjBgxgscee4xrrrmGZ599li5duvD5558DMHr0aM455xyGDh3K119/zdq1a2v/5qpo/okAwpe+vvhFpBZ+/OMf07JlSwCWLVvGySefzLvvvouZsXr16mq3Ofzww2nTpg1t2rRhiy224OOPP6Zr1641HuvVV19l/PjxAJx44olcdNFFAPzgBz9g2LBhHHvssRx99NEA7Lnnnlx77bVUVFRw9NFH07Nnz3q/18JIBCLSJNTll3uubLLJJuunr7jiCvbbbz8effRR5s+fz4ABA6rdpk2bNuunW7ZsyZo1a+oVw+jRo5k8eTITJkygb9++TJ06lZ/85Cf069ePCRMmcNhhh3HHHXew//771+s4zf8egYhIPS1btowuXboAMHbs2Kzvv3///jz44IMAlJWVsffeewPw3nvv0a9fP6655hqKiopYuHAh8+bN
o0ePHowYMYLBgwczY8aMeh9fiUBEpAYXXXQRl156KX369Kn3r3yAXr160bVrV7p27cp5553HLbfcwj333EOvXr249957uemmmwC48MIL2XnnnSkpKaF///7ssssuPPTQQ5SUlNC7d29mzZrFSSedVO94zN3rvZOGVFpa6hqYRqT5mDNnDjvssEO+w2hWqjunZjbV3att66orAhGRAqdEICJS4HKWCMxsazN70czeNLPZZnZONesMNbMZZjbTzF4xs11yFY+IiFQvl81H1wDnu/s0M+sATDWz5939zZR13gf2dffPzOxQYAzQL4cxiYhIFTlLBO6+GFgcp5eb2RygC/BmyjqvpGzyGlDzkxciIpJVDXKPwMyKgT7A5AyrnQo8nWb74WZWbmblS5YsyUGEIiKFK+dPFptZe+ARYKS7f5Fmnf0IiWCv6pa7+xhCtRGlpaVNq72riDRqS5cu5YADDgDgo48+omXLlhQVFQEwZcoUNtpoo4zbT5w4kY022oj+/ft/a9nYsWMpLy/n1ltvzX7gWZTTRGBmrQlJoMzdx6dZpxdwJ3Couy/NZTwiIlXV1A11TSZOnEj79u2rTQRNRS5bDRlwFzDH3W9Is043YDxworu/k6tYRKQZaYAxyKdOncq+++5L3759OeSQQ1i8eDEAN998MzvuuCO9evViyJAhzJ8/n9GjR/OHP/yB3r17M2nSpET7v+GGGygpKaGkpIQbYwdLK1as4PDDD2eXXXahpKSEcePGAXDJJZesP2ZtElRt5PKK4AfAicBMM5seyy4DugG4+2jg58DmwG1xnM016Z58ExFpiDHI3Z2zzz6bxx9/nKKiIsaNG8eoUaO4++67ue6663j//fdp06YNn3/+OR07duT000+v1VXE1KlTueeee5g8eTLuTr9+/dh3332ZN28eW221FRMmTABC/0ZLly7l0Ucf5a233sLM1ndFnW25bDX0MpBxFGV3Pw04LVcxiEgzk2kM8iwlgq+++opZs2Zx0EEHAbB27Vo6d+4MhD6Chg4dypFHHsmRRx5Zp/2//PLLHHXUUet7Nz366KOZNGkSAwcO5Pzzz+fiiy9m0KBB7L333qxZs4a2bdty6qmnMmjQoPWD4GSbniwWkaajAcYgd3d22mknpk+fzvTp05k5cybPPfccABMmTODMM89k2rRp7LbbblnpgK7Stttuy7Rp09h55525/PLLueaaa2jVqhVTpkzhmGOO4cknn2TgwIFZO14qJQIRaTrSjTWexTHI27Rpw5IlS3j11VcBWL16NbNnz2bdunUsXLiQ/fbbj+uvv55ly5bx5Zdf0qFDB5YvX554/3vvvTePPfYYK1euZMWKFTz66KPsvffefPjhh7Rr144TTjiBCy+8kGnTpvHll1+ybNkyDjvsMP7whz/wxhtvZO19ptLANCLSdFx77TfvEUDWxyBv0aIFf/3rXxkxYgTLli1jzZo1jBw5km233ZYTTjiBZcuW4e6MGDGCjh07csQRR3DMMcfw+OOPc8stt6wfS6DS2LFjeeyxx9bPv/baawwbNozdd98dgNNOO40+ffrw7LPPcuGFF9KiRQtat27N7bffzvLlyxk8eDCrVq3C3bnhhmrb3dSbuqEWkbyqdTfUZWUag7wGte2GWlcEItK0aAzyrNM9AhGRAqdEICJ519SqqBuzupxLJQIRyau2bduydOlSJYMscHeWLl1K27Zta7Wd7hGISF517dqViooK1LNwdrRt25auXWvXo78SgYjkVevWrenevXu+wyhoqhoSESlwSgQiIgVOiUBEpMApEYiIFLgaE4GZ/cbMNjWz1mb2gpktMbMTGiI4ERHJvSRXBAfHsYYHAfOB/wEurGkjM9vazF40szfNbLaZnVPNOmZmN5vZXDObYWa71vYNiIhI/SRpPlq5zuHAw+6+LI4mVpM1wPnuPs3MOgBTzex5d38zZZ1DgZ7x1Q+4Pf4VEZEGkuSK4EkzewvoC7xgZkXAqpo2cvfF
7j4tTi8H5gBdqqw2GPiLB68BHc2sc63egYiI1EuNicDdLwH6A6XuvhpYQfgCT8zMioE+wOQqi7oAC1PmK/h2ssDMhptZuZmV6+lDEZHsSnKz+MfAandfa2aXA/cBWyU9gJm1Bx4BRsZ7DbXm7mPcvdTdS4uKiuqyCxERSSNJ1dAV7r7czPYCDgTuItTl18jMWhOSQJm7j69mlUXA1inzXWOZiIg0kCSJYG38ezgwxt0nABvVtJGFO8p3AXPcPd34ak8AJ8XWQ3sAy9x9cYKYREQkS5K0GlpkZncABwHXm1kbkiWQHwAnAjPNbHosuwzoBuDuo4GngMOAucBK4Ke1il5EROotSSI4FhgI/M7dP4+temp8jsDdXwYytjP10AH5mUkCFRGR3EjSamgl8B5wiJmdBWzh7s/lPDIREWkQSVoNnQOUAVvE131mdnauAxMRkYaRpGroVKCfu68AMLPrgVeBW3IZmIiINIwkN32NDS2HiNOJ+pgQEZHGL8kVwT3AZDN7NM4fSWgWKiIizUCNicDdbzCzicBesein7v56TqMSEZEGkzYRmNlmKbPz42v9Mnf/T+7CEhGRhpLpimAq4Gy4H+Dxr8XpHjmMS0REGkjaRODu3RsyEBERyQ+NWSwiUuCUCERECpwSgYhIgUvaauhb1GpIRKR5SNpqqBvwWZzuCCwAdDNZRKQZSFs15O7d3b0H8HfgCHfv5O6bA4MA9T4qItJMJLlHsIe7P1U54+5PEwazFxGRZiBJIvjQzC43s+L4GgV8WNNGZna3mX1iZrPSLP+Omf3NzN4ws9lmptHJRETyIEkiOB4oAh6Nry1iWU3GEkY2S+dM4E133wUYAPzezGocC1lERLIrSadz/wHOMbMOYda/TLJjd3/JzIozrQJ0iIPctwf+A6xJsm8REcmeJCOU7WxmrwOzgNlmNtXMSrJw7FuBHQjVTDOBc9x9XZoYhptZuZmVL1myJAuHFhGRSkmqhu4AznP3bdx9G+B8YEwWjn0IMB3YCugN3Gpmm1a3oruPcfdSdy8tKirKwqFFRKRSkkSwibu/WDnj7hOBTbJw7J8C4z2YC7wPbJ+F/YqISC0kSQTzzOyKlFZDlwPzsnDsBcABAGb2PWC7LO1XRERqIclQlacAVwPj4/ykWJaRmT1AaA3UycwqgCuB1gDuPhr4BTDWzGYSnli+2N0/re0bEBGR+knSaugzYEQdWg1lbGLq7h8CByeKUkREciafrYZERKQRyGerIRERaQTy2WpIREQagSQ3i+eZ2RXAvXH+BNS6R0Sk2UhyRXAKoa+h8fFVRIJWQyIi0jQkbjXUALGIiEge1JgIzGxb4AKgOHV9d98/d2GJiEhDSXKP4GFgNHAnsDa34YiISENLkgjWuPvtOY9ERETyIm0iMLPN4uTfzOxnhEFpvqpcHscpEBGRJi7TFcFUwuAxFucvTFnmQI9cBSUiIg0nbSJw9+4NGYiIiORHpqqh/d39H2Z2dHXL3X18deUiItK0ZKoa2hf4B3BENcucDd1Si4hIE5apaujK+PenDReOiIg0tExVQ+dl2tDdb8i03MzuBgYBn7h7td1Wm9kA4EbCgDWfuvu+mcMVEZFsy1Q11KGe+x4L3Ar8pbqFZtYRuA0Y6O4LzGyLeh5PRETqIFPV0NX12bG7v2RmxRlW+Qlh8PoFcf1P6nM8ERGpmyQjlG1rZi+Y2aw43ysOYF9f2wLfNbOJcdSzkzLEMNzMys2sfMmSJVk4tIiIVErSDfWfgEuB1QDuPgMYkoVjtwL6AocDhwBXxA7uvsXdx7h7qbuXFhUVZeHQIiJSKUlfQ+3cfYqZpZatycKxK4Cl7r4CWGFmLwG7AO9kYd8iIpJQkiuCT83s+4RnBzCzY4DFWTj248BeZtbKzNoB/YA5WdiviIjUQpIrgjMJg9Vvb2aLgPeBoTVtZGYPAAOATmZWAVxJaCaKu4929zlm9gwwA1gH3Onus+r0LkREpM6SJILvuvuBZrYJ
0MLdl5vZIOCDTBu5+/E17djdfwv8NlmoIiKSC4luFptZibuviElgCHBFrgMTEZGGkeSK4Bjgr2b2E2Bv4CTg4JxGJSIiDSbJ4PXz4lXAY8AC4GB3/2+uAxMRkYaRqa+hmcSWQtFmQEtgspnh7r1yHZyIiORepiuCQQ0WhYiI5E2mRPCZu3+RMnaxiIg0Q5kSwf2Eq4KqYxeDxiwWEWk2MvU+Oij+1djFIiLNWKabxbtm2tDdp2U/HBERaWiZqoZ+n2GZA/tnORYREcmDTFVD+zVkICIikh9JupgQEZFmTIlARKTAKRGIiBS4GvsaStN6aBnwgbtnY6QyERHJoyRXBLcBrxEGp/kT8CrwMPC2maXthdTM7jazTyoHvc+w3m5mtiaOfCYiIg0sSSL4EOgTB4/vC/QB5gEHAb/JsN1YYGCmHZtZS+B64LlE0YqISNYlSQTbuvvsyhl3fxPY3t3nZdrI3V8C/lPDvs8GHgE+SRCHiIjkQJKBaWab2e3Ag3H+OOBNM2sDrK7rgc2sC3AUsB+wWw3rDgeGA3Tr1q2uhxQRkWokuSIYBswFRsbXvFi2mvAlXlc3Ahe7+7qaVnT3MbFqqrSoqKgehxQRkaqSjFD2XzO7hVCP78Db7l55JfBlPY5dCjxoZgCdgMPMbI27P1aPfYqISC0laT46APgzMJ/QFfXWZnZyvAdQZ6m9mprZWOBJJQERkYaX5B7B7wnjFL8NYGbbAg8AfTNtZGYPAAOATmZWAVwJtAZw99H1iFlERLIoSSJoXZkEANz9HTNrXdNG7n580iDcfVjSdUVEJLuSJIJyM7sTuC/ODwXKcxeSiIg0pCSJ4AzgTGBEnJ9EeNpYRESagSSthr4CbogvERFpZjINVTmT0Fy0Wu7eKycRiYhIg8p0RTCowaIQEZG8Sftksbt/kOnVkEFKM1VWBsXF0KJF+FtWlu+IRApSkpvFItlXVgbDh8PKlWH+gw/CPMDQofmLS6QAaYQyyY9RozYkgUorV4ZyEWlQiRKBmW1sZtvlOhgpIAsW1K5cRHKmxkRgZkcA04Fn4nxvM3six3FJc5euO3F1My7S4JJcEVwF7A58DuDu04Hu6VcXSeDaa6Fdu2+WtWsXykWkQSVJBKvdfVmVsrTPF4gkMnQojBkD22wDZuHvmDG6USySB0lHKPsJ0NLMehK6mnglt2FJQRg6VF/8Io1AkiuCs4GdgK+A+4FlhJHKRESkGUhyRbC9u48C1K5PRKQZSnJF8Hszm2NmvzCzkpxHJCIiDarGRODu+xEGqV8C3GFmM83s8pq2M7O7zewTM5uVZvlQM5sR9/eKme1S6+hFRKTeEj1Q5u4fufvNwOmEZwp+nmCzscDADMvfB/Z1952BXwBjksQiIiLZleSBsh3M7KrYLfUthBZDXWvaLg5u/58My19x98/i7GtJ9ikiItmX5Gbx3cA44BB3/zBHcZwKPJ1uoZkNB4YDdNOTpyIiWZVkhLI9cxmAme1HSAR7ZYhhDLHqqLS0VA+ziYhkUaYRyh5y92OrGanMAM/GCGVm1gu4EzjU3ZfWd38iIlJ7ma4Izol/czJSmZl1A8YDJ7r7O7k4hoiI1CxtInD3xXHyZ+5+ceoyM7seuPjbW31jnQeAAUAnM6sArgRax32PJrQ82hy4zcwA1rh7ad3ehoiI1JW5Z65yN7Np7r5rlbIZ+Rq8vrS01MvLy/NxaBGRJsvMpqb7sZ3pHsEZwM+AHmY2I2VRB+Bf2Q1RRETyJdM9gvsJTTp/DVySUr7c3dM+HyAiIk1LpnsEywg9jR4PYGZbAG2B9mbW3t01pqCISDOQaKhKM3uX0CXEP4H5ZHj4S0REmpYkfQ39EtgDeMfduwMHELqEEJFmZs0aePJJuPlm+I8qgAtG0qEqlwItzKyFu78IqJmnSDMydy5cdhl0K1rJEUfAOedAcaflXHHkTCWEApAkEXxuZu2Bl4AyM7sJWJHbsEQk11auhHvvhQEDoGdPuP66dfT9
YiLjOYpp9GGgP80vH9+Z4i6rueIKXSE0Z0kSwWDgv8C5wDPAe8ARuQxKRHLDHcrL4YwzoHNnOOkkqKiAa6+FBVvtyd/WHc5RPEYfpvMQxzGTEga2eJZf/hKKi1FCaKZqfKCssdEDZSK1t3QplJXBXXfBjBnQti0ccwyceirssw+0aEH4T3XfB2bMmrGOa66Bhx+GDh1gxAg47zzYbLMGfytSR5keKEvSami5mX1R5bXQzB41sx7ZD1dEsmHdOnj+eRgyBLbaKtT7t24Nt90GixdvqBZqUfktkK6L927dKCmBhx6CmTNh4MBwBVFcDJdfriuE5iBJ1dCNwIVAF8LgMRcQHjZ7kDBWgYg0IgsWwNVXQ48ecPDB8Nxz8H//B9Onb6gW6tixmg2vvRbatftmWbt2oTxKTQiHHgq/+pUSQrPg7hlfwBvVlE1PtyzXr759+7qIfNOqVe7jxrkffLC7mTu4H3ig+wMPuP/3v7XY0X33uW+zTdjJNtuE+QxmznQ/9tiweocO7qNGuS9dWp93IrkClHua79UkVwQrzexYM2sRX8cCqyrzSC6Sk4gkM3MmjBwJXbrAccfBnDnhhu7772+oFmrbthY7HDoU5s8P9Urz54f5DEpKYNy4cN9BVwhNV5JEMBQ4EfgE+DhOn2BmGwNn5TA2EanGF1/AmDGw++7Qq1eo899/f3jmmZAArr46fBk3pEwJYamGnGr01GpIpAlwh5dfDq1+Hn44PANQUhJa/ZxwAnTqlO8Iv2nWLPjFL0Ks7duHVkbnngubb57vyApXfVsNbWtmL5jZrDjfy8wuT7Dd3Wb2SeV21Sw3M7vZzOaa2Qwz27W69UQK2UcfwfXXw/bbh2ae48eHL/7Jk8Ov75EjG18SgOqvELp31xVCY5WkauhPwKXAagB3nwEMSbDdWGBghuWHAj3jazhwe4J9ijR7a9bAE0/A4MHQtStccglssQXcc09o9nnHHaFaKAzs17gpITQNSRJBO3efUqVsTU0buftLQKbbRYOBv8Qb2q8BHc2sc4J4RJqld94JX/pbbx2SwOTJcP758NZbMGkSDBsGm2yS7yjrJt09hFGjlBAagySJ4FMz+z6xhZCZHQMszrxJIl2AhSnzFbHsW8xsuJmVm1n5kiVLsnBokcZhxQr4859Dtc9228Hvfhd+7T/2GCxcGKqFttsu31FmT2VCmDkTDj8cfv1rJYTGIEkiOBO4A9jezBYBI4EzchlUVe4+xt1L3b20qKioIQ8tknXuMGVKeMirc+fwS/+jj8KX4sKF8Pjj4Yqgdet8R5o7O+0EDz6ohNBY1JgI3H2eux8IFAHbu/te7j4/C8deBGydMt81lok0S59+CjfeGJp89usXung46ij45z/h7bdDtVDnAqscVUJoHJK0GmpjZj8BzgHONbOfm9nPs3DsJ4CTYuuhPYBl7p6NKieRRmPtWnj2WTj22PDQ17nnwsYbw+jR4cZvZbVQU7jxm0tKCDUoKwsnpEWL8LesLKu7r/E5AjN7hjB28VRgbWW5u/++hu0eAAYAnQgPol0JtI7bjjYzA24ltCxaCfzU3Wt8QEDPEUhTMH9+aOUzdmzo+2ezzeDEE0O7/513znd0jd/s2eE5hIceCjfIK3s7LcjnEMrKYPjw8PBIpXbtwlOFNTz5nSrTcwRJEsEsdy9JfLQcUyKQxmrVqnCT96674IUXQtlBB4Uv/8GDoU2bvIbXJCkhEK4APvjg2+XbbBN+cSRUrwfKgFfMTL9hRNJ4443wBdWlCxx/PLz7Llx5ZejuobJaSEmgblRlRLikrE15HSRJBHsBU83s7fgE8Ewzm5G1CESaoGXLQj3/brtB797hIa+DDgpdPs+bFxLBNtvkO8rmI11CuOyycBO+WcswTkS2JEkElU8AH0wYonIQGqpSCpB7aOFz0kmhdc8ZZ8DXX8NNN8GHH4YvqoMOShnoRbKuakK47rrwpHKzTggJxomot3T9UzfWl8YjkIa2aJH7r37l
/j//E/r533RT99NPd//3v93Xrct3dIVt9mz3IUPCeAjt27tfeqn7kiX5jioHajlORHXIMB6Beh8Vqcbq1TBhQrjx+9RToXv+ffYJN36POebbP9Akv958M9xUHjcu3FQ+++xwU7kxdsiXL/W9WSxSMN5+Gy66KPT3c9RRMHVqmH/nnQ3VQkoCjc+OO8IDD4TurwcNKpAqoyxSIpCC9+WXoc3/XnuF7p5vuAH22CP0ALpgQbgx2bNnvqOUJJQQ6kaJQAqSO7z2Gvzv/4Ybv6ecAkuWhE7eKirC8wBHHAGtWuU7UqkLJYTaUSKQgrJkSfjFX1ICe+4J998f6vwnTQrdPV90EWy5Zb6jlGxRQkhGiUCavbVr4emnwxd+ly6hj/8OHcIT+osXb6gWKvT+fpozJYTMlAik2Xr/fbjiivDg0WGHhZu9Z50Vvgwqq4U23TTfUUpDqi4hFBfDpZcWdkJQIpBmZdWqUN1zwAHQo0d45qakJAyivmhRqBbaaad8Ryn5lpoQfvjDcG+okBOCEoE0C6+/Hn7td+4cOmScNw+uuSb01VVZLbTRRvmOUhqbHXcMPxxmzy7shKBEIE3W55/DbbdB376w665w551hPNy//x3eey9UC229dY27EWGHHQo7ISgRSJOybh28+CKccEL49X/mmaHslltCfz+V1ULq70fqolATQk7/uZjZwNhr6Vwzu6Sa5d3M7EUzez32bHpYLuORpmvRolDf37Mn7L8/PPlkaPs/deqGaqHNNst3lNJcFFpCyFkiMLOWwB8JvZfuCBxvZjtWWe1y4CF37wMMAW7LVTzS9Hz9NYwfH3qZ7NYNLr88/L333tDs849/DFVCIrlSKAkhl1cEuwNz3X2eu38NPAgMrrKOA5UN+L4DfJjDeKSJmDMHLrgAunaFH/0Ipk8PA7vPnbuhWmjjjfMdpRSS5p4QcpkIugALU+YrYlmqq4ATzKwCeAo4u7odmdlwMys3s/IlS5bkIlbJsy+/DD199u8fWnLcdFN4yOvJJ0PLn2uvhe9/P99RSqFLlxAuuSQ8td5U5fuW2vHAWHfvChwG3Gtm34rJ3ce4e6m7lxYVFTV4kJIb7vDKK6Fr5y23hNNOg88+g9/+NvT3U1ktpP5+pLGpmhB+85vwpHJTTQi5TASLgNTGe11jWapTgYcA3P1VoC2gHsQLxKpV4YnfcePguOPgX/8K/cpfcAF873v5jk6kZqkJYfDgppsQcpkI/g30NLPuZrYR4WbwE1XWWQAcAGBmOxASQRM6fVIfG28cBn1ZvHhDtZD6+5GmaIcdoKys6SaEnCUCd18DnAU8C8whtA6abWbXmNkP42rnA/9rZm8ADwDDvKkNmSb10r9/6ABOpDloqglBQ1WKiOTInDnwy1+Gfo3atQvPu5x/PuTjVqeGqhQRyYOmcoWgRCAikmOVCeHNNxtnQlAiEBFpINtv3zgTghKBiEgDS5cQLr44PwlBiUBEJE9SE8KRR8LvfheeVG7ohKBEICKSZ9tvD/fdF24qH3VUwycEJQIRkUYiXwlBiUBEpJFJlxDGjMnN8ZQIREQaqaoJoXv33BxH/TqKiDRylQkhV3RFICJS4JQIREQKnBKBiEiBUyIQESlwSgQiIgVOiUBEpMApEYiIFDglAhGRAtfkhqo0syXAB3XcvBPwaRbDyaWmEqvizL6mEqvizK5cx7mNu1c7SGaTSwT1YWbl6cbsbGyaSqyKM/uaSqyKM7vyGaeqhkRECpwSgYhIgSu0RJCjTlxzoqnEqjizr6nEqjizK29xFtQ9AhER+bZCuyIQEZEqlAhERApcwSQCMxtoZm+b2Vwzu6SBjrm1mb1oZm+a2WwzOyeWX2Vmi8xsenwdlrLNpTHGt83skJriN7PuZjY5lo8zs43qGOt8M5sZ4ymPZZuZ2fNm9m78+91YbmZ2czzmDDPbNWU/J8f13zWzk1PK+8b9z43bWh1i3C7lnE03sy/MbGRjOZ9mdreZfWJms1LKcn4O
0x2jlnH+1szeirE8amYdY3mxmf035dyOrms8md5zLeLM+WdtZm3i/Ny4vDhTnBliHZcS53wzm57vc5qWuzf7F9ASeA/oAWwEvAHs2ADH7QzsGqc7AO8AOwJXARdUs/6OMbY2QPcYc8tM8QMPAUPi9GjgjDrGOh/oVKXsN8AlcfoS4Po4fRjwNGDAHsDkWL4ZMC/+/W6c/m5cNiWua3HbQ7PwmX4EbNNYziewD7ArMKshz2G6Y9QyzoOBVnH6+pQ4i1PXq7KfWsWT7j3XMs6cf9bAz4DRcXoIMK4un32V5b8Hfp7vc5ruVShXBLsDc919nrt/DTwIDM71Qd19sbtPi9PLgTlAlwybDAYedPev3P19YC4h9mrjj78W9gf+Grf/M3BkFt/C4LjPqvseDPzFg9eAjmbWGTgEeN7d/+PunwHPAwPjsk3d/TUP//f+JQtxHgC85+6ZnjJv0PPp7i8B/6kmhlyfw3THSBynuz/n7mvi7GtA10z7qGM86d5z4jgzyOZnnRr/X4EDKn+Z1yXWuO2xwAOZ9tEQ5zSdQkkEXYCFKfMVZP5Czrp4edkHmByLzoqXcnenXMqnizNd+ebA5yn/gOvzvhx4zsymmtnwWPY9d18cpz8CvlfHOLvE6arl9TGEb/7Damzns1JDnMN0x6irUwi/Mit1N7PXzeyfZrZ3Svy1jSdb/w5z/Vmv3yYuXxbXr6u9gY/d/d2UskZ1TgslEeSVmbUHHgFGuvsXwO3A94HewGLCZWO+7eXuuwKHAmea2T6pC+MvlEbR1jjW5f4QeDgWNcbz+S0NcQ7rewwzGwWsAcpi0WKgm7v3Ac4D7jezTRsqnmo0ic+6iuP55o+WxnZOCyYRLAK2TpnvGstyzsxaE5JAmbuPB3D3j919rbuvA/5EuHzNFGe68qWES8FWVcprzd0Xxb+fAI/GmD6uvMyMfz+pY5yL+GZVQ33P/6HANHf/OMbc6M5nioY4h+mOUStmNgwYBAyNXzbEqpalcXoqob592zrGU+9/hw30Wa/fJi7/Tly/1uL2RwPjUt5DozqnUDiJ4N9Az9hKYCNCtcITuT5orBu8C5jj7jeklKfW4R0FVLY0eAIYElstdAd6Em4eVRt//Mf6InBM3P5k4PE6xLmJmXWonCbcOJwV46lstZK67yeAk2KLhT2AZfGy9VngYDP7brxkPxh4Ni77wsz2iOfkpLrEmeIbv7Aa2/msoiHOYbpjJGZmA4GLgB+6+8qU8iIzaxmnexDO4bw6xpPuPdcmzob4rFPjPwb4R2VirIMDgbfcfX2VT2M7p0BhtBryDXfX3yFk31ENdMy9CJdwM4Dp8XUYcC8wM5Y/AXRO2WZUjPFtUlrWpIuf0BpiCuHm2MNAmzrE2YPQmuINYHbl/gn1oi8A7wJ/BzaL5Qb8McYyEyhN2dcpMZa5wE9TyksJ/2jfA24lPtVeh1g3Ifw6+05KWaM4n4TktBhYTairPbUhzmG6Y9QyzrmEuubK/08rW838KP4/MR2YBhxR13gyvedaxJnzzxpoG+fnxuU96vLZx/KxwOlV1s3bOU33UhcTIiIFrlCqhkREJA0lAhGRAqdEICJS4JQIREQKnBKBiEiBUyKQrDOziWaW80G4zWyEmc0xs7Iq5b0tpVfKWuxvKzP7a4L1nrLYO2dzYGYDzOzJfMch+dOq5lVEGo6ZtfIN/b/U5GfAgZ7ysE7Um9Ae+6na7N/dP2TDA0ZpuXutk4xIY6YrggJloU/0OWb2JwtjJTxnZhvHZet/0ZtZJzObH6eHmdljFvpDn29mZ5nZeRY6z3rNzDZLOcSJFvpan2Vmu8ftN7HQUdiUuM3glP0+YWb/IDw0UzXW8+J+ZpnZyFg2mvBA0NNmdm7KuhsB1wDHxeMfZ6EP+3vN7F/AvfG9TzKzafHVP+WczEqJabyZPWOhD/jfpBxjfjwvmc7hbhY6Rptuoa//9f3UV3lvF5rZv+O6V8eyo8zshfi0aGcze8fMtswQ9wAL
nZc9bmbzzOw6Mxsaz/NMM/t+XG+smY02s/K4z0HVxJPuM9oplk2Psfassl3LuP9Z8ZjnxvLvx3M4Nca+fSwvMrNH4nv/t5n9IJZfFY8/Mb6XEdWdN8myujyFplfTfxH6RF8D9I7zDwEnxOmJxCcUgU7A/Dg9jPC0ZQegiNAr4+lx2R8InepVbv+nOL0Pse914Fcpx+hIeNpzk7jfCqp5IhboS3hichOgPeGJzD5x2XyqjKGQEuetKfNXAVOBjeN8O6BtnO4JlKeck1kp+5hH6GemLfABsHXqcWs4h7OAPeP0dVTT/zyh+4gxhKdDWwBPAvvEZfcBZ8Wy42uIewDwOWH8izaEvmaujsvOAW6M02OBZ+KxesZz3jZu/2QNn9EthD6IIPTrv3E1n9PzKfMd498XgJ5xuh+huwaA+wkdHQJ0I3TDUvlZvRLfRyfCU+St8/3vpbm/VDVU2N539+lxeirhi60mL3oYW2G5mS0D/hbLZwK9UtZ7AEI/7Wa2qYU69YOBH5rZBXGdtoQvAYh98FdzvL2AR919BYCZjSd06/t6glhTPeHu/43TrYFbzaw3sJbQ4Vd1XnD3ZfG4bxIGwVlYZZ1vncP4Xju4+6ux/H5CZ25VHRxfle+lPeEL+iXgbEIyec3dK/tVyhT3vz32MWNm7wHPxfKZwH4p6z3kocO2d81sHrB9NTFV9xm9Cowys67AeP9ml8oQkmYPM7sFmEDo0rw90B942DZ0598m/j0Q2DGlfNO4PsAEd/8K+MrMPiF0uVy1+k+ySImgsH2VMr0W2DhOr2FDtWHbDNusS5lfxzf/f6rad4kTfvn+yN3fTl1gZv2AFbWKvPZS938u8DGwC+F9rkqzTdXzU92/l3TnMAkDfu3ud1SzrCvhnH7PzFrEL+9Mcdfnc6ka07c+I2COmU0GDgeeMrP/c/d/rN+J+2dmtgthYJ3TCQOxjCT0+d+7mvfXAtjD3b9x7mNiSHLeJYt0j0CqM59wqQ8Jbp6mcRyAme1F6BFxGaFnzbPN1o/D2ifBfiYBR5pZOws9ox4VyzJZTqi+Suc7wOL45XoiYTjDrHH3zwlXTP1i0ZA0qz4LnFL5S9jMupjZFha6Lr6b0MvqHEKf9dmK+8dm1iLeN+hB6KCtakzf+ows9JI5z91vJvR8mXr1h5l1Alq4+yPA5YQhWr8A3jezH8d1LCYLCFcsZ6ds37sO70WyRIlAqvM74Awze51QT1sXq+L2owm9RgL8glC9McPMZsf5jDwM9TmW0AvkZOBOd6+pWuhFQrXDdDM7rprltwEnm9kbhKqRXFyNnAr8ycKA5ZsQ7qd8g7s/R6g2etXMZhKGRewAXAZMcveXCUngNDPbIUtxLyCcy6cJ93eqXg2l+4yOBWbF91NCGEYxVRdgYlx+H3BpLB8KnBpjns2GIWJHAKXxxvObhKsIyRP1PiqSA2bW3t2/jNOXELpLPifPMY0l3BSu8VkJKSyqexPJjcPN7FLCv7EPCK2QRBolXRGIiBQ43SMQESlwSgQiIgVOiUBEpMApEYiIFDglAhGRAvf/tjwjqxXuS1UAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.figure()\n", "plt.plot(train_counter, train_losses, color='blue')\n", "plt.scatter(test_counter, test_losses, color='red')\n", "plt.legend(['Train Loss', 'Test Loss'], loc='upper right')\n", "plt.xlabel('number of training examples seen')\n", "plt.ylabel('negative log likelihood loss')\n", "#fig" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cv:Python", "language": "python", "name": "conda-env-cv-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }