{ "cells": [ { "cell_type": "markdown", "id": "83a57396", "metadata": {}, "source": [ "## Training Notebook\n", "\n", "This notebook shows how to train a simple model that classifies handwritten digits from the MNIST dataset. The code here was used to train the model included with the templates. It is meant as a starter model that shows you how to set up Serverless applications to run inference. For a deeper understanding of how to train a good MNIST model, we recommend the literature listed on the [MNIST website](http://yann.lecun.com/exdb/mnist/). The dataset is made available under a [Creative Commons Attribution-Share Alike 3.0](https://creativecommons.org/licenses/by-sa/3.0/) license." ] }, { "cell_type": "code", "execution_count": 2, "id": "de99c6fc", "metadata": {}, "outputs": [], "source": [ "# We'll use scikit-learn to load the dataset\n", "\n", "! pip install -q scikit-learn==0.23.2" ] }, { "cell_type": "code", "execution_count": 25, "id": "b6bf6336", "metadata": {}, "outputs": [], "source": [ "# Load the MNIST dataset\n", "\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split\n", "\n", "# as_frame=False returns NumPy arrays, which the reshape step later in this notebook relies on\n", "X, y = fetch_openml('mnist_784', return_X_y=True, as_frame=False)\n", "\n", "# Hold out 10,000 of the 70,000 samples for testing\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=10000)" ] }, { "cell_type": "markdown", "id": "48d0a541", "metadata": {}, "source": [ "## TensorFlow Model Training\n", "\n", "For this example, we will train a simple CNN classifier using TensorFlow to classify the MNIST digits. We will then save the model in the HDF5 (`.h5`) format. This is the same as the starter model file included with the SAM templates." ] }, { "cell_type": "code", "execution_count": 27, "id": "8205faa8", "metadata": {}, "outputs": [], "source": [ "! 
pip install -q tensorflow==2.8.0" ] }, { "cell_type": "code", "execution_count": 28, "id": "7fdaebfb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using TesorFlow version 2.4.0\n", "Epoch 1/15\n", "1875/1875 [==============================] - 13s 6ms/step - loss: 0.4003 - accuracy: 0.8761\n", "Epoch 2/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0878 - accuracy: 0.9731\n", "Epoch 3/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0655 - accuracy: 0.9792\n", "Epoch 4/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0538 - accuracy: 0.9840\n", "Epoch 5/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0459 - accuracy: 0.9854\n", "Epoch 6/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0377 - accuracy: 0.9882\n", "Epoch 7/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0362 - accuracy: 0.9884\n", "Epoch 8/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0292 - accuracy: 0.9909\n", "Epoch 9/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0272 - accuracy: 0.9915\n", "Epoch 10/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0268 - accuracy: 0.9915\n", "Epoch 11/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0177 - accuracy: 0.9944\n", "Epoch 12/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0219 - accuracy: 0.9934\n", "Epoch 13/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0161 - accuracy: 0.9952\n", "Epoch 14/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0181 - accuracy: 0.9944\n", "Epoch 15/15\n", "1875/1875 [==============================] - 12s 6ms/step - loss: 0.0161 - accuracy: 0.9947\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "\n", "print(f'Using TensorFlow version {tf.__version__}')\n", "\n", "# Reshape each flat 784-value row into a 28x28x1 tensor\n", "X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n", "X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)\n", "\n", "# Convert the labels to integers (fetch_openml returns them as strings)\n", "y_train = y_train.astype(np.int8)\n", "y_test = y_test.astype(np.int8)\n", "\n", "model = tf.keras.Sequential([\n", "    # First convolutional block; input_shape matches the reshaped images above\n", "    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n", "    tf.keras.layers.MaxPooling2D((2, 2)),\n", "\n", "    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),\n", "    tf.keras.layers.MaxPooling2D((2, 2)),\n", "\n", "    tf.keras.layers.Flatten(),\n", "    tf.keras.layers.Dense(100, activation='relu'),\n", "    tf.keras.layers.BatchNormalization(),\n", "    tf.keras.layers.Dense(100, activation='relu'),\n", "    tf.keras.layers.BatchNormalization(),\n", "    tf.keras.layers.Dense(100, activation='relu'),\n", "    tf.keras.layers.BatchNormalization(),\n", "    tf.keras.layers.Dense(100, activation='relu'),\n", "    tf.keras.layers.BatchNormalization(),\n", "    tf.keras.layers.Dense(100, activation='relu'),\n", "    tf.keras.layers.BatchNormalization(),\n", "\n", "    # Output layer: one raw score (logit) per digit class\n", "    tf.keras.layers.Dense(10)\n", "])\n", "\n", "# from_logits=True because the output layer applies no softmax\n", "model.compile(optimizer='adam',\n", "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "              metrics=['accuracy'])\n", "\n", "model.fit(X_train, y_train, epochs=15)" ] }, { "cell_type": "code", "execution_count": 29, "id": "42f2375d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "313/313 - 1s - loss: 0.0404 - accuracy: 0.9909\n", "\n", "Test accuracy: 0.9908999800682068\n" ] } ], "source": [ "# Evaluate the trained model on the held-out test set\n", "test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)\n", "\n", "print('\\nTest 
accuracy:', test_acc)" ] }, { "cell_type": "code", "execution_count": 30, "id": "eb22ad91", "metadata": {}, "outputs": [], "source": [ "# Save the trained model to disk in HDF5 (.h5) format\n", "model.save('tf_digit_classifier.h5')" ] }, { "cell_type": "code", "execution_count": null, "id": "620bcd9a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }