{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Amazon SageMaker - Bring Your Own Model \n", "## PyTorch Edition\n", "\n", "This notebook explains how to migrate [PyTorch](https://pytorch.org/) sample code so that it runs on Amazon SageMaker. The [SDK documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html) also contains a lot of information on using PyTorch with the SageMaker Python SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Rewriting the training script\n", "\n", "### Why a rewrite is needed\n", "Amazon SageMaker uses the object storage service Amazon S3 to store data. For example, when you specify training data on S3, the data is automatically downloaded to the SageMaker training instance and the training script is executed. If the training script saves the model to the designated directory, the model is automatically uploaded to S3 once the script finishes.\n", "\n", "When you bring a training script to SageMaker, you need to modify the following points.\n", "- Loading the training data that was downloaded to the training instance\n", "- Saving the model when training completes\n", "\n", "These changes are much the same as the ones you would make when moving a training script to any other environment. For example, on your own PC you might want to read data from a directory such as `/home/user/data` and save the model to `/home/user/model`. You need to do the same thing on SageMaker.\n", "\n", "### Decide on the storage locations before rewriting\n", "\n", "In this hands-on, the training and validation data downloaded from S3 and the model uploaded to S3 are stored on the training instance at the locations listed below. Paths such as `/opt/ml/input/data/train/` may look odd, but they can be read from environment variables, which keeps the code simple. [1-1. 
Getting environment variables](#env) explains how to read them.\n", "\n", "#### Training data\n", "- Images: `/opt/ml/input/data/train/image.npy`\n", "- Labels: `/opt/ml/input/data/train/label.npy`\n", "\n", "#### Validation data\n", "- Images: `/opt/ml/input/data/test/image.npy`\n", "- Labels: `/opt/ml/input/data/test/label.npy`\n", "\n", "#### Model\n", "Save the model files (for PyTorch, typically the parameters) under `/opt/ml/model`.\n", "\n", "### What to rewrite\n", "First, download the [sample source code](https://github.com/pytorch/examples/blob/master/mnist/main.py) with the following command." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget https://raw.githubusercontent.com/pytorch/examples/master/mnist/main.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Find the downloaded `main.py` in the file browser and open it (in JupyterLab, placing the files side by side makes this easier). You can also use any editor you like. Inside `def main()`, `main.py` calls the following functions and downloads the data from a source other than S3.\n", "\n", "```python\n", "dataset1 = datasets.MNIST('../data', train=True, download=True,\n", "                   transform=transform)\n", "dataset2 = datasets.MNIST('../data', train=False,\n", "                   transform=transform)\n", "```\n", "\n", "That approach also works, but here we download the training data from S3 and read it from paths such as `/opt/ml/input/data/train/`, as described above. There are four main points to rewrite:\n", "\n", "1. 
Getting environment variables  \n", "    On SageMaker, data is downloaded from S3 into predetermined directories, and the trained model is saved to a predetermined directory as well. These paths can be read from environment variables, so the location of the training data does not have to be hard-coded in the training script. The paths can of course be changed, and they can also be passed through the API.\n", "\n", "1. Modifying the arguments  \n", "    The SageMaker API that launches training accepts a dictionary called hyperparameters. Its entries are passed to the training script as command-line arguments. For example, training `main.py` with\n", "    ```\n", "    hyperparameters = {'epochs': 100}\n", "    ```\n", "    is roughly equivalent to running `python main.py --epochs 100`. Arguments that cannot be expressed as dictionary entries cannot be handled as they are, so the script has to be modified to accept them.  \n", "1. Loading the training data  \n", "    Once the environment variables tell you where the training data is stored, rewrite the code so that it loads the training data from that location.\n", "\n", "1. Changing the format and output location of the trained model  \n", "    SageMaker hosts models with the [model server for PyTorch](https://github.com/aws/sagemaker-pytorch-inference-toolkit) and can use PyTorch models in `.pth` or `.pt` format. The trained model must be saved to the correct location. The training instance is deleted when training completes, so change the output location to the designated directory so that the model is uploaded to S3." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### <a name=\"env\"></a>1-1. 
Getting environment variables\n", "\n", "When you train on Amazon SageMaker, the Python script used for training (here, the PyTorch script) runs on a training instance that is separate from the notebook instance. The input and output paths for the data and the model are available through environment variables such as `SM_CHANNEL_XXXX` and `SM_MODEL_DIR`, as described [here](https://sagemaker.readthedocs.io/en/stable/using_tf.html#preparing-a-script-mode-training-script).\n", "\n", "Here we read the values of the environment variables `SM_CHANNEL_TRAIN` (path of the training data), `SM_CHANNEL_TEST` (path of the test data), and `SM_MODEL_DIR` (path where the model is saved). Add the following code directly under `def main():` to read them.\n", "\n", "```python\n", "def main():\n", "    import os\n", "    train_dir = os.environ['SM_CHANNEL_TRAIN']\n", "    test_dir = os.environ['SM_CHANNEL_TEST']\n", "    model_dir = os.environ['SM_MODEL_DIR']\n", "```\n", "\n", "We now have the locations of the training data, the validation data, and the model. Section 1-3 implements the code that actually reads these files." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-2. 
Modifying the arguments\n", "\n", "Arguments that cannot be expressed as dictionary entries cannot be passed when launching SageMaker training. A flag such as `python main.py --save-model`, which sets `save_model` to True simply by being present, is one example. Instead, pass the string 'True' or 'False' and convert it to the Boolean value True/False inside the training script. After the change, hyperparameters is passed like this:\n", "```python\n", "hyperparameters = {'save-model': 'True'}\n", "```\n", "\n", "The training script that receives the arguments has to change accordingly. Specifically, the code that receives Boolean values contains `action='store_true'`, as in\n", "\n", "```python\n", "parser.add_argument('--no-cuda', action='store_true', default=False,\n", "                    help='disables CUDA training')\n", "```\n", "\n", "so this is what we modify. Replace `action='store_true'` with `type=strtobool` so that the library function `strtobool` converts the string to a truth value (it returns 1 or 0).\n", "\n", "```python\n", "parser.add_argument('--no-cuda', type=strtobool, default=False,\n", "                    help='disables CUDA training')\n", "```\n", "\n", "**Do not forget to add `from distutils.util import strtobool` at the top of main().** Note that `distutils` was removed in Python 3.12, so this import only works on older Python versions such as the one used in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-3. 
Loading the training data\n", "\n", "The original code uses `datasets.MNIST` to download and load the data. Specifically, it is the following lines inside `main()`. Because we download the data from S3 this time, this code is no longer needed. **Delete it now.**\n", "```python\n", "    transform=transforms.Compose([\n", "        transforms.ToTensor(),\n", "        transforms.Normalize((0.1307,), (0.3081,))\n", "        ])\n", "    dataset1 = datasets.MNIST('../data', train=True, download=True,\n", "                       transform=transform)\n", "    dataset2 = datasets.MNIST('../data', train=False,\n", "                       transform=transform)\n", "    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)\n", "    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n", "```\n", "\n", "In their place, implement code that reads the data downloaded from S3. The variables `train_dir` and `test_dir` obtained from the environment variables hold the paths of the directories containing the data, which are `/opt/ml/input/data/train` and `/opt/ml/input/data/test` respectively. See the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-trainingdata) for details. In the default File mode, the data is copied from S3 into these directories when the training container starts; if you specify Pipe mode, the data is streamed from S3 instead.\n", "\n", "All we need is code that reads the npy files, so add code like the following. Make use of the fact that the paths are stored in `train_dir` and `test_dir`. The npy data is stored as uint8, so we scale the image values into the range 0 to 1.\n", "```python\n", "import numpy as np\n", "train_image = torch.from_numpy(np.load(os.path.join(train_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "train_image = torch.unsqueeze(train_image, 1)\n", "train_label = torch.from_numpy(np.load(os.path.join(train_dir, 'label.npy'), allow_pickle=True).astype(np.int64))\n", 
"train_dataset = torch.utils.data.TensorDataset(train_image, train_label)\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)\n", "\n", "test_image = torch.from_numpy(np.load(os.path.join(test_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "test_image = torch.unsqueeze(test_image, 1)\n", "test_label = torch.from_numpy(np.load(os.path.join(test_dir, 'label.npy'), allow_pickle=True).astype(np.long))\n", "test_dataset = torch.utils.data.TensorDataset(test_image, test_label)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)\n", "```\n", "\n", "#### 確èª\n", "\n", "ã“ã“ã¾ã§ã®ä¿®æ£ã§ `main()` ã®å†’é ã®å®Ÿè£…ãŒä»¥ä¸‹ã®æ§˜ã«ãªã£ã¦ã„ã‚‹ã“ã¨ã‚’確èªã—ã¾ã—ょã†ã€‚\n", "\n", "```python\n", "def main():\n", " import os\n", " from distutils.util import strtobool\n", " train_dir = os.environ['SM_CHANNEL_TRAIN']\n", " test_dir = os.environ['SM_CHANNEL_TEST']\n", " model_dir = os.environ['SM_MODEL_DIR']\n", " # Training settings\n", " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", " parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", " help='input batch size for training (default: 64)')\n", " parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", " help='input batch size for testing (default: 1000)')\n", " parser.add_argument('--epochs', type=int, default=14, metavar='N',\n", " help='number of epochs to train (default: 14)')\n", " parser.add_argument('--lr', type=float, default=1.0, metavar='LR',\n", " help='learning rate (default: 1.0)')\n", " parser.add_argument('--gamma', type=float, default=0.7, metavar='M',\n", " help='Learning rate step gamma (default: 0.7)')\n", " parser.add_argument('--no-cuda', type=strtobool, default=False,\n", " help='disables CUDA training')\n", " parser.add_argument('--dry-run', type=strtobool, default=False,\n", " help='quickly check a single pass')\n", " 
parser.add_argument('--seed', type=int, default=1, metavar='S',\n", "                        help='random seed (default: 1)')\n", "    parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n", "                        help='how many batches to wait before logging training status')\n", "    parser.add_argument('--save-model', type=strtobool, default=False,\n", "                        help='For Saving the current Model')\n", "    args = parser.parse_args()\n", "    use_cuda = not args.no_cuda and torch.cuda.is_available()\n", "\n", "    torch.manual_seed(args.seed)\n", "\n", "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", "\n", "    train_kwargs = {'batch_size': args.batch_size}\n", "    test_kwargs = {'batch_size': args.test_batch_size}\n", "\n", "    if use_cuda:\n", "        cuda_kwargs = {'num_workers': 1,\n", "                       'pin_memory': True,\n", "                       'shuffle': True}\n", "        train_kwargs.update(cuda_kwargs)\n", "        test_kwargs.update(cuda_kwargs)\n", "\n", "    import numpy as np\n", "    train_image = torch.from_numpy(np.load(os.path.join(train_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "    train_image = torch.unsqueeze(train_image, 1)\n", "    train_label = torch.from_numpy(np.load(os.path.join(train_dir, 'label.npy'), allow_pickle=True).astype(np.int64))\n", "    train_dataset = torch.utils.data.TensorDataset(train_image, train_label)\n", "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)\n", "\n", "    test_image = torch.from_numpy(np.load(os.path.join(test_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "    test_image = torch.unsqueeze(test_image, 1)\n", "    test_label = torch.from_numpy(np.load(os.path.join(test_dir, 'label.npy'), allow_pickle=True).astype(np.int64))\n", "    test_dataset = torch.utils.data.TensorDataset(test_image, test_label)\n", "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-4. 
Changing the output location of the trained model\n", "\n", "The instance is deleted when training completes, but the files in `/opt/ml/model` are compressed into model.tar.gz and saved to S3. We save the model `mnist_cnn.pt` there and finish training. The path `/opt/ml/model` was read from an environment variable and stored in the variable `model_dir`, so we use it to specify the output location.\n", "\n", "Rewrite the following model-saving code\n", "```python\n", "    if args.save_model:\n", "        torch.save(model.state_dict(), \"mnist_cnn.pt\")\n", "```\n", "\n", "as follows.\n", "```python\n", "    if args.save_model:\n", "        torch.save(model.state_dict(), os.path.join(model_dir,\"mnist_cnn.pt\"))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Preparing the data in the notebook\n", "\n", "The training script is now rewritten. Before training starts, the data has to be prepared on Amazon S3. We use this notebook to do that." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import boto3\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "\n", "sagemaker_session = sagemaker.Session()\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use MNIST, a dataset of handwritten digits widely used in machine learning. The following cells download the dataset as PNG images from a public S3 bucket, load the images into NumPy arrays, and split the 70,000 images into training, validation, and test data. Only the training data `X_train, y_train` and the validation data `X_valid, y_valid` are used for training, so we first save these in npy format." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, json\n", "NOTEBOOK_METADATA_FILE = \"/opt/ml/metadata/resource-metadata.json\"\n", "if 
os.path.exists(NOTEBOOK_METADATA_FILE):\n", "    with open(NOTEBOOK_METADATA_FILE, \"rb\") as f:\n", "        metadata = json.loads(f.read())\n", "    # SageMaker Studio metadata contains a DomainId\n", "    on_studio = metadata.get(\"DomainId\") is not None\n", "else:\n", "    # No metadata file: assume we are not on Studio\n", "    on_studio = False\n", "print(\"Is this notebook running on Studio?: {}\".format(on_studio))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python -m pip install -U scikit-image\n", "!aws s3 cp s3://fast-ai-imageclas/mnist_png.tgz . --no-sign-request\n", "if on_studio:\n", "    !tar -xzf mnist_png.tgz -C /opt/ml --no-same-owner\n", "else:\n", "    !tar -xvzf mnist_png.tgz" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage.io import ImageCollection,concatenate_images\n", "from PIL import Image\n", "import numpy as np\n", "import pathlib\n", "\n", "def load_image_with_label(f):\n", "    label = pathlib.PurePath(f).parent.name\n", "    return np.array(Image.open(f)), label\n", "if on_studio:\n", "    dataset = ImageCollection(\"/opt/ml/mnist_png/*/*/*.png\", load_func=load_image_with_label)\n", "else:\n", "    dataset = ImageCollection(\"./mnist_png/*/*/*.png\", load_func=load_image_with_label)\n", "np_dataset = np.array(dataset, dtype=\"object\")\n", "X = concatenate_images(np_dataset[:,0])\n", "y = np_dataset[:,1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "index = np.random.permutation(70000)\n", "X = X[index]\n", "y = y[index]\n", "\n", "X_train = X[0:50000,0:784]\n", "X_valid = X[50000:60000,0:784]\n", "X_test = X[60000:70000,0:784]\n", "y_train = y[0:50000]\n", "y_valid = y[50000:60000]\n", "y_test = y[60000:70000]\n", "\n", "os.makedirs('data/train', exist_ok=True)\n", "os.makedirs('data/valid', exist_ok=True)\n", "np.save('data/train/image.npy', X_train)\n", "np.save('data/train/label.npy', y_train)\n", "# Save the validation split; X_test and y_test are kept for the endpoint test later\n", "np.save('data/valid/image.npy', X_valid)\n", "np.save('data/valid/label.npy', y_valid)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "ã“れを Amazon S3 ã«ã‚¢ãƒƒãƒ—ãƒãƒ¼ãƒ‰ã—ã¾ã™ã€‚" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data = sagemaker_session.upload_data(path='data/train', key_prefix='data/mnist-npy/train')\n", "valid_data = sagemaker_session.upload_data(path='data/valid', key_prefix='data/mnist-npy/valid')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. トレーニングã®å®Ÿè¡Œ\n", "\n", "`from sagemaker.pytorch import PyTorch` ã§èªã¿è¾¼ã‚“ã SageMaker Python SDK ã® PyTorch Estimator を作りã¾ã™ã€‚\n", "\n", "ã“ã“ã§ã¯ã€å¦ç¿’ã«åˆ©ç”¨ã™ã‚‹ã‚¤ãƒ³ã‚¹ã‚¿ãƒ³ã‚¹æ•° `instance_count` ã‚„ インスタンスタイプ `instance_type` を指定ã—ã¾ã™ã€‚\n", "Docker を実行å¯èƒ½ãªç’°å¢ƒã§ã‚れã°ã€`instance_type = \"local\"` ã¨æŒ‡å®šã™ã‚‹ã¨ã€è¿½åŠ ã®ã‚¤ãƒ³ã‚¹ã‚¿ãƒ³ã‚¹ã‚’èµ·å‹•ã™ã‚‹ã“ã¨ãªãã€ã„ã¾ã€ã“ã®ãƒŽãƒ¼ãƒˆãƒ–ックを実行ã—ã¦ã„る環境ã§ãƒˆãƒ¬ãƒ¼ãƒ‹ãƒ³ã‚°ã‚’実行ã§ãã¾ã™ã€‚インスタンス起動を待ã¤å¿…è¦ãŒãªã„ãŸã‚デãƒãƒƒã‚°ã«ä¾¿åˆ©ã§ã™ã€‚\n", "\n", "hyperparameters ã§æŒ‡å®šã—ãŸå†…容をトレーニングスクリプトã«å¼•æ•°ã¨ã—ã¦æ¸¡ã™ã“ã¨ãŒã§ãã¾ã™ã®ã§ã€`hyperparameters = {\"epoch\": 3}` ã¨ã—㦠3 エãƒãƒƒã‚¯ã ã‘実行ã—ã¦ã¿ã¾ã—ょã†ã€‚" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", "\n", "instance_type = \"ml.m4.xlarge\"\n", "\n", "mnist_estimator = PyTorch(entry_point='main.py',\n", " role=role,\n", " instance_count=1,\n", " instance_type=instance_type,\n", " framework_version='1.8.1',\n", " py_version='py3',\n", " hyperparameters = {\"epoch\": 3, \n", " \"save-model\": \"True\"})\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`estimator.fit` ã«ã‚ˆã‚Šãƒˆãƒ¬ãƒ¼ãƒ‹ãƒ³ã‚°ã‚’é–‹å§‹ã—ã¾ã™ãŒã€ã“ã“ã§æŒ‡å®šã™ã‚‹ã€Œãƒãƒ£ãƒãƒ«ã€ã«ã‚ˆã£ã¦ã€ç’°å¢ƒå¤‰æ•°å `SM_CHANNEL_XXXX` ãŒæ±ºå®šã•れã¾ã™ã€‚ã“ã®ä¾‹ã®å ´åˆã€`'train', 'test'` を指定ã—ã¦ã„ã‚‹ã®ã§ã€`SM_CHANNEL_TRAIN`, `SM_CHANNEL_TEST` ã¨ãªã‚Šã¾ã™ã€‚トレーニングスクリプトã§ç’°å¢ƒå¤‰æ•°ã‚’å‚ç…§ã—ã¦ã„ã‚‹å ´åˆã¯ã€fit 
arguments, and confirm that the two match." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "mnist_estimator.fit({'train': train_data, 'test': valid_data})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the rewrite of `main.py` contains a mistake, an error may occur here.\n", "\n", "The job succeeded if the following is displayed:\n", "\n", "`===== Job Complete =====`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Checking the trained model\n", "\n", "The model saved on Amazon S3 can also be downloaded and used like any other file. Its location is available from `estimator.model_data`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist_estimator.model_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Creating the inference script" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The trained model can be hosted on SageMaker. The endpoint then receives inference requests from clients and returns the inference results.\n", "\n", "Hosting requires a script that (1) loads the trained model and (2) runs inference, implemented in the functions `model_fn` and `transform_fn` respectively. No other functions need to be implemented.\n", "\n", "1. 
model_fn(model_dir)  \n", "    `model_fn` is called with the trained model extracted into `model_dir`. Normally it only contains code that loads the model and returns it. With PyTorch it is common to save and reuse only the model parameters, so the network architecture has to be defined in the inference code.\n", "\n", "```python\n", "from io import BytesIO\n", "import json\n", "import numpy as np\n", "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", "        self.dropout1 = nn.Dropout(0.25)\n", "        self.dropout2 = nn.Dropout(0.5)\n", "        self.fc1 = nn.Linear(9216, 128)\n", "        self.fc2 = nn.Linear(128, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "        x = self.conv2(x)\n", "        x = F.relu(x)\n", "        x = F.max_pool2d(x, 2)\n", "        x = self.dropout1(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.fc1(x)\n", "        x = F.relu(x)\n", "        x = self.dropout2(x)\n", "        x = self.fc2(x)\n", "        output = F.log_softmax(x, dim=1)\n", "        return output\n", "\n", "def model_fn(model_dir):\n", "    model = Net()\n", "    with open(os.path.join(model_dir, \"mnist_cnn.pt\"), \"rb\") as f:\n", "        model.load_state_dict(torch.load(f))\n", "    model.eval() # for inference\n", "    return model\n", "```\n", "\n", "When loading multiple models, or when additional files such as a vocabulary file are needed (as in NLP), load them all and return them in a dict or similar. Whatever is returned is handed to `transform_fn(model, request_body, request_content_type, response_content_type)` as `model`.\n", "\n", "2. 
transform_fn(model, request_body, request_content_type, response_content_type)  \n", "    Write code that passes the inference request (request_body) to the loaded model and returns the inference result. For example, if requests can arrive in several formats and the preprocessing of request_body should depend on the format, have the client specify a content_type, receive it as request_content_type, and branch on it.\n", "\n", "    request_body arrives as bytes. Read it according to the format the client sent. For example, if the data was sent in numpy format, read it into a numpy array with `np.load(BytesIO(request_body))`. With PyTorch, inference is usually run on a Torch Tensor, so convert the data accordingly and return the inference result. If needed, returning the result in the format specified by response_content_type lets the client choose how to receive the result.\n", "\n", "    Here we receive numpy data and return the result as JSON.\n", "\n", "```python\n", "def transform_fn(model, request_body, request_content_type, response_content_type):\n", "    # Decode the npy payload and scale pixel values to [0, 1], as in training\n", "    input_data = np.load(BytesIO(request_body)).astype(np.float32)/255\n", "    input_data = torch.from_numpy(input_data)\n", "    input_data = torch.unsqueeze(input_data, 1)\n", "    with torch.no_grad():  # gradients are not needed for inference\n", "        prediction = model(input_data)\n", "    return json.dumps(prediction.tolist())\n", "```\n", "\n", "Put all of the code above into a file named `deploy.py`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "mnist_model=PyTorchModel(model_data=mnist_estimator.model_data,\n", "                         role=role,\n", "                         entry_point='deploy.py',\n", "                         framework_version='1.8.1',\n", "                         py_version='py3')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor=mnist_model.deploy(instance_type='ml.m4.xlarge', initial_instance_count=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the test dataset, we randomly 
select 10 images and run a test. The SageMaker Predictor for PyTorch assumes numpy format by default, so to receive JSON, specify `JSONDeserializer()`. The results for the 10 images are then displayed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy.special import softmax\n", "\n", "test_size = 10\n", "select_idx = np.random.choice(np.arange(y_test.shape[0]), test_size)\n", "test_sample = X_test[select_idx].reshape([test_size,28,28]).astype(np.float32)\n", "\n", "predictor.deserializer=sagemaker.deserializers.JSONDeserializer()\n", "result = predictor.predict(test_sample)\n", "\n", "result = softmax(np.array(result), axis=1)\n", "predict_class = np.argmax(result, axis=1)\n", "print(\"Predicted labels: {}\".format(predict_class))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Checking the images\n", "Let's look at the actual images." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "W = 10 # number of images per row\n", "H = 10 # number of rows\n", "fig = plt.figure(figsize=(H, W))\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1.0, hspace=0.05, wspace=0.05)\n", "for i in range(test_size):\n", "    ax = fig.add_subplot(H, W, i + 1, xticks=[], yticks=[])\n", "    ax.set_title(\"{} ({:.3f})\".format(predict_class[i], result[i][predict_class[i]]), color=\"green\")\n", "    ax.imshow(test_sample[i].reshape((28, 28)), cmap='gray')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An inference endpoint incurs charges for as long as it is running, so delete it as soon as you have finished checking the results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Summary" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook introduced the steps for migrating PyTorch code to Amazon SageMaker. The same steps apply to the models you normally use, so please give it a try." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (PyTorch 1.6 Python 3.6 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.6-cpu-py36-ubuntu16.04-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }