{ "cells": [ { "cell_type": "markdown", "id": "501d2b94", "metadata": {}, "source": [ "## Training Notebook\n", "\n", "This notebook illustrates training of a simple model to classify digits using the MNIST dataset. This code is used to train the model included with the templates. It is meant to be a starter model that shows you how to set up serverless applications to run inference. For a deeper understanding of how to train a good model for MNIST, we recommend the literature on the [MNIST website](http://yann.lecun.com/exdb/mnist/). The dataset is made available under a [Creative Commons Attribution-Share Alike 3.0](https://creativecommons.org/licenses/by-sa/3.0/) license." ] }, { "cell_type": "code", "execution_count": 1, "id": "528eaf05", "metadata": {}, "outputs": [], "source": [ "# We'll use scikit-learn to load the dataset\n", "\n", "! pip install -q scikit-learn==0.23.2" ] }, { "cell_type": "code", "execution_count": 2, "id": "927e21aa", "metadata": {}, "outputs": [], "source": [ "# Load the MNIST dataset\n", "\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split\n", "\n", "# as_frame=False returns NumPy arrays, which the deskewing code below expects\n", "X, y = fetch_openml('mnist_784', return_X_y=True, as_frame=False)\n", "\n", "# We limit training to 10000 images for faster training. 
Remove train_size to use all examples.\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1000, train_size=10000)" ] }, { "cell_type": "code", "execution_count": 6, "id": "ec45d7e2", "metadata": {}, "outputs": [], "source": [ "# Next, let's add code for deskewing images (we will use this to improve accuracy)\n", "# This code comes from https://fsix.github.io/mnist/Deskewing.html\n", "\n", "import numpy as np\n", "from scipy.ndimage import affine_transform\n", "\n", "def moments(image):\n", "    # Row and column coordinate grids matching the image shape\n", "    c0, c1 = np.mgrid[:image.shape[0], :image.shape[1]]\n", "    img_sum = np.sum(image)\n", "\n", "    # Intensity-weighted center of mass along each axis\n", "    m0 = np.sum(c0 * image) / img_sum\n", "    m1 = np.sum(c1 * image) / img_sum\n", "    # Intensity-weighted variances and covariance about the center of mass\n", "    m00 = np.sum((c0-m0)**2 * image) / img_sum\n", "    m11 = np.sum((c1-m1)**2 * image) / img_sum\n", "    m01 = np.sum((c0-m0) * (c1-m1) * image) / img_sum\n", "\n", "    mu_vector = np.array([m0,m1])\n", "    covariance_matrix = np.array([[m00, m01],[m01, m11]])\n", "\n", "    return mu_vector, covariance_matrix\n", "\n", "\n", "def deskew(image):\n", "    c, v = moments(image)\n", "    # Skew estimate: covariance divided by the variance along the row axis\n", "    alpha = v[0,1] / v[0,0]\n", "    affine = np.array([[1,0], [alpha,1]])\n", "    ocenter = np.array(image.shape) / 2.0\n", "    # Offset that maps the image center onto the center of mass\n", "    offset = c - np.dot(affine, ocenter)\n", "\n", "    return affine_transform(image, affine, offset=offset)\n", "\n", "\n", "def deskew_images(images):\n", "    output_images = []\n", "\n", "    # Each row is a flattened 28x28 image\n", "    for image in images:\n", "        output_images.append(deskew(image.reshape(28, 28)).flatten())\n", "\n", "    return np.array(output_images)\n" ] }, { "cell_type": "markdown", "id": "f3059144", "metadata": {}, "source": [ "## Scikit-learn Model Training\n", "\n", "For this example, we will train a simple SVM classifier using scikit-learn to classify the MNIST digits. We will then save the model in the `.joblib` format. This is the same as the starter model file included with the SAM templates."
] }, { "cell_type": "code", "execution_count": 4, "id": "6f1ad87c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using scikit-learn version: 0.23.2\n", "Test accuracy without deskewing: 0.959\n", "CPU times: user 27.1 s, sys: 46.7 ms, total: 27.2 s\n", "Wall time: 27.2 s\n" ] } ], "source": [ "%%time\n", "\n", "import sklearn\n", "import numpy as np\n", "\n", "from sklearn.metrics import accuracy_score\n", "from sklearn import svm\n", "\n", "print(f'Using scikit-learn version: {sklearn.__version__}')\n", "\n", "# Fit our training data\n", "# Note: degree only applies to the 'poly' kernel; SVC defaults to the RBF kernel, which ignores it\n", "clf = svm.SVC(degree=5)\n", "clf.fit(X_train, y_train)\n", "\n", "# Evaluate the fitted model on the held-out test set\n", "accuracy = accuracy_score(y_test, clf.predict(X_test))\n", "\n", "print('Test accuracy without deskewing:', accuracy)" ] }, { "cell_type": "code", "execution_count": 9, "id": "6298da9d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy with deskewing: 0.971\n", "CPU times: user 22.7 s, sys: 19.9 ms, total: 22.7 s\n", "Wall time: 22.7 s\n" ] } ], "source": [ "%%time\n", "\n", "# Let's try this again with deskewing on\n", "\n", "# Fit our training data\n", "clf = svm.SVC(degree=5)\n", "clf.fit(deskew_images(X_train), y_train)\n", "\n", "# Evaluate the fitted model on the deskewed held-out test set\n", "accuracy = accuracy_score(y_test, clf.predict(deskew_images(X_test)))\n", "\n", "print('Test accuracy with deskewing:', accuracy)" ] }, { "cell_type": "code", "execution_count": 19, "id": "1fa8f89f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['digit_classifier.joblib']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import joblib\n", "\n", "# Save the model to disk with compression to keep size low\n", "joblib.dump(clf, 'digit_classifier.joblib', compress=3)" ] }, { "cell_type": "code", "execution_count": null, "id": "c9774df7", "metadata": {}, "outputs": 
[], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }