{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building your first Artificial Neural Network with AWS\n",
    "\n",
    "#### Predicting fashion type using Zalando's Fashion-MNIST dataset (https://github.com/zalandoresearch/fashion-mnist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "Copyright [2017]-[2017] Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at\n",
    "\n",
    "http://aws.amazon.com/apache2.0/\n",
    "\n",
    "or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import gzip\n",
    "import logging\n",
    "import math\n",
    "import os\n",
    "import struct\n",
    "import urllib\n",
    "\n",
    "# urlretrieve lives in urllib.request on Python 3 and in urllib on Python 2\n",
    "try:\n",
    "    from urllib.request import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib import urlretrieve\n",
    "\n",
    "import cv2\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "import mxnet as mx\n",
    "import mxnet.notebook.callback\n",
    "import numpy as np\n",
    "import scipy.misc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare training and test datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def download_data(url, force_download=True, fname=None):\n",
    "    \"\"\"Download `url` and return the local path of the file.\n",
    "\n",
    "    url: address of the file to fetch.\n",
    "    force_download: when False, an already existing local file is reused.\n",
    "    fname: local destination path; defaults to the last component of the\n",
    "        URL, relative to the current working directory.\n",
    "    \"\"\"\n",
    "    if fname is None:\n",
    "        fname = url.split(\"/\")[-1]\n",
    "    # make sure the destination directory exists before writing into it\n",
    "    target_dir = os.path.dirname(fname)\n",
    "    if target_dir and not os.path.isdir(target_dir):\n",
    "        os.makedirs(target_dir)\n",
    "    if force_download or not os.path.exists(fname):\n",
    "        urlretrieve(url, fname)\n",
    "    return fname\n",
    "\n",
    "\n",
    "def to4d(img):\n",
    "    \"\"\"Reshape a stack of 28x28 images to (N, 1, 28, 28) float32 and divide by 255.\"\"\"\n",
    "    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32) / 255\n",
    "\n",
    "\n",
    "def read_data(label, image):\n",
    "    \"\"\"Fetch one label/image archive pair and decode it into NumPy arrays.\n",
    "\n",
    "    label, image: file names of the gzipped IDX archives below `base_url`.\n",
    "    The archives are stored under ./data and reused when already present.\n",
    "    Returns a (labels, images) tuple.\n",
    "    \"\"\"\n",
    "    base_url = 'https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/'\n",
    "    label_path = download_data(base_url + label, force_download=False, fname=os.path.join('data', label))\n",
    "    image_path = download_data(base_url + image, force_download=False, fname=os.path.join('data', image))\n",
    "    with gzip.open(label_path, 'rb') as flbl:\n",
    "        # 8-byte header: magic number and item count\n",
    "        magic, num = struct.unpack(\">II\", flbl.read(8))\n",
    "        # frombuffer yields a read-only view; copy to hand out a writable array\n",
    "        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()\n",
    "    with gzip.open(image_path, 'rb') as fimg:\n",
    "        # 16-byte header: magic number, item count, rows, columns\n",
    "        magic, num, rows, cols = struct.unpack(\">IIII\", fimg.read(16))\n",
    "        images = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols).copy()\n",
    "    return (labels, images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# fixed seeds so that shuffling and weight initialisation are repeatable\n",
    "np.random.seed(42)\n",
    "mx.random.seed(42)\n",
    "\n",
    "batch_size = 100\n",
    "(train_lbl, train_img) = read_data('train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')\n",
    "(val_lbl, val_img) = read_data('t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')\n",
    "train_data_iter = mx.io.NDArrayIter(data={'fashion_data': to4d(train_img)}, label={'fashion_item_label': train_lbl}, batch_size=batch_size, shuffle=True)\n",
    "test_data_iter = mx.io.NDArrayIter(data={'fashion_data': to4d(val_img)}, label={'fashion_item_label': val_lbl}, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display example training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    plt.subplot(1, 10, i + 1)\n",
    "    dsp_img = cv2.bitwise_not(train_img[i])\n",
    "    plt.imshow(dsp_img, cmap='Greys_r')\n",
    "    plt.axis('off')\n",
    "\n",
    "plt.show()\n",
    "print('label: %s' % (train_lbl[0:10],))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Zalando fashion labels https://github.com/zalandoresearch/fashion-mnist\n",
    "\n",
    "fashion_labels=['T-shirt/top',\n",
    "'Trouser',\n",
    "'Pullover',\n",
    "'Dress',\n",
    "'Coat',\n",
    "'Sandal',\n",
    "'Shirt',\n",
    "'Sneaker',\n",
    "'Bag',\n",
    "'Ankle boot']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build MXNet model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fashion_item_label = mx.symbol.Variable('fashion_item_label')\n",
    "# input\n",
    "data = mx.symbol.Variable('fashion_data')\n",
    "# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)\n",
    "data = mx.sym.flatten(data=data, name='flatten')\n",
    "\n",
    "# 1st fully-connected layer + activation function\n",
    "fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)\n",
    "act1 = mx.sym.Activation(data=fc1, act_type=\"relu\")\n",
    "# 2nd fully-connected layer + activation function\n",
    "fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)\n",
    "act2 = mx.sym.Activation(data=fc2, act_type=\"relu\")\n",
    "# 3rd fully connected layer (Fashion-MNIST has 10 classes)\n",
    "fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)\n",
    "# softmax with cross entropy loss\n",
    "mlp = mx.sym.SoftmaxOutput(data=fc3, label=fashion_item_label, name='softmax')\n",
    "\n",
    "mx.viz.plot_network(mlp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model and commit checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "logging.basicConfig(level=logging.INFO)\n",
    "logging.getLogger().setLevel(logging.INFO)\n",
    "\n",
    "\n",
    "def pick_context():\n",
    "    \"\"\"Return mx.gpu() when a GPU is usable, otherwise fall back to mx.cpu().\"\"\"\n",
    "    try:\n",
    "        # allocating on the GPU and reading the value back raises MXNetError\n",
    "        # when MXNet was built without CUDA or no device is present\n",
    "        mx.nd.zeros((1,), ctx=mx.gpu()).asnumpy()\n",
    "        return mx.gpu()\n",
    "    except mx.base.MXNetError:\n",
    "        return mx.cpu()\n",
    "\n",
    "\n",
    "ctx = pick_context()\n",
    "\n",
    "mod = mx.mod.Module(symbol=mlp, data_names=['fashion_data'], label_names=['fashion_item_label'], context=ctx, logger=logging)\n",
    "mod.bind(data_shapes=train_data_iter.provide_data, label_shapes=train_data_iter.provide_label)\n",
    "mod.init_params(initializer=mx.init.Xavier(magnitude=2.))\n",
    "\n",
    "mod.fit(train_data_iter,  # train data\n",
    "        eval_data=test_data_iter,  # validation data\n",
    "        optimizer='sgd',  # use SGD to train\n",
    "        optimizer_params={'learning_rate': 0.1},  # use fixed learning rate\n",
    "        eval_metric=mx.metric.Accuracy(),  # report accuracy during training\n",
    "        num_epoch=10,  # train for at most 10 dataset passes\n",
    "        epoch_end_callback=mx.callback.do_checkpoint('fashion_mnist'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run predictions for 10 example elements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pred_data_iter = mx.io.NDArrayIter(data={'fashion_data': to4d(val_img)[0:100]}, batch_size=batch_size)\n",
    "pred_probs = mod.predict(eval_data=pred_data_iter).asnumpy()\n",
    "\n",
    "# show validation images 10..19, the same ones reported below\n",
    "for i in range(10):\n",
    "    plt.subplot(1, 10, i + 1)\n",
    "    plt.imshow(val_img[i + 10], cmap='Greys')\n",
    "    plt.axis('off')\n",
    "plt.show()\n",
    "\n",
    "for x in range(10, 20):\n",
    "    best = int(np.argmax(pred_probs[x, 0:10]))\n",
    "    print(\"Predicted fashion label for image %s is %s (%s)\" % (x, best, fashion_labels[best]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading images for prediction from amazon.com"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!wget -O predict1.jpg https://images-na.ssl-images-amazon.com/images/I/81OaXwn1x4L._UX679_.jpg\n",
    "!wget -O predict2.jpg https://images-eu.ssl-images-amazon.com/images/I/31TcgNHsbIL._AC_UL260_SR200,260_.jpg\n",
    "!wget -O predict3.jpg https://images-eu.ssl-images-amazon.com/images/I/41hWhZBIc3L._AC_UL260_SR200,260_.jpg\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load model from checkpoint for prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# checkpoint number to load; training above writes one checkpoint per epoch (num_epoch=10)\n",
    "prediction_model_check_point = 10\n",
    "prediction_model_prefix = 'fashion_mnist'\n",
    "prediction_sym, arg_params, aux_params = mx.model.load_checkpoint(prediction_model_prefix, prediction_model_check_point)\n",
    "prediction_model = mx.mod.Module(symbol=prediction_sym, data_names=['fashion_data'], label_names=['fashion_item_label'])\n",
    "prediction_model.bind(for_training=False, data_shapes=[('fashion_data', (1,1,28,28))])\n",
    "prediction_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)\n",
    "\n",
    "\n",
    "# define prediction function\n",
    "def predict_fashion(img):\n",
    "    \"\"\"Classify one 28x28 single-channel image and print the predicted label.\n",
    "\n",
    "    img: 2-D array-like holding 28x28 pixel values (divided by 255 before inference).\n",
    "    Returns the index of the predicted class in `fashion_labels`.\n",
    "    \"\"\"\n",
    "    # format data to run prediction\n",
    "    array = np.asarray(img, dtype=np.float32).reshape(1, 28, 28)\n",
    "    pred_data_iter = mx.io.NDArrayIter(data={'fashion_data': to4d(array)}, batch_size=1)\n",
    "\n",
    "    pred_probs = prediction_model.predict(eval_data=pred_data_iter).asnumpy()\n",
    "    label = int(np.argmax(pred_probs[0]))\n",
    "\n",
    "    print(\"Predicted fashion label for image is %s (%s) \" % (label, fashion_labels[label]))\n",
    "    return label\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict labels for downloaded images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for i in range(3):\n",
    "    img = mpimg.imread('predict' + str(i + 1) + '.jpg')\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    # get colours in line with train data\n",
    "    img = cv2.bitwise_not(img)\n",
    "    # average the colour channels; a greyscale JPEG is already 2-D\n",
    "    if img.ndim == 3:\n",
    "        img = np.mean(img, -1)\n",
    "\n",
    "    # resize image\n",
    "    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)\n",
    "\n",
    "    predict_fashion(img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}