{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Amazon SageMaker - Bring Your Own Model \n", "## PyTorch 編\n", "\n", "ここでは [PyTorch](https://pytorch.org/) のサンプルコードをAmazon SageMaker 上で実行するための移行手順について説明します。SageMaker Python SDK で PyTorch を使うための説明は [SDK のドキュメント](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html) にも多くの情報があります。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. トレーニングスクリプトの書き換え\n", "\n", "### 書き換えが必要な理由\n", "Amazon SageMaker では、オブジェクトストレージ Amazon S3 をデータ保管に利用します。例えば、S3 上の学習データを指定すると、自動的に Amazon SageMaker の学習用インスタンスにデータがダウンロードされ、トレーニングスクリプトが実行されます。トレーニングスクリプトを実行した後に、指定したディレクトリにモデルを保存すると、自動的にモデルがS3にアップロードされます。\n", "\n", "トレーニングスクリプトを SageMaker に持ち込む場合は、以下の点を修正する必要があります。\n", "- 学習用インスタンスにダウンロードされた学習データのロード\n", "- 学習が完了したときのモデルの保存\n", "\n", "これらの修正は、トレーニングスクリプトを任意の環境に持ち込む際の修正と変わらないでしょう。例えば、自身のPCに持ち込む場合も、`/home/user/data` のようなディレクトリからデータを読み込んで、`/home/user/model` にモデルを保存したいと考えるかもしれません。同様のことを SageMaker で行う必要があります。\n", "\n", "### 書き換える前に保存先を決める\n", "\n", "このハンズオンでは、S3からダウンロードする学習データ・バリデーションデータと、S3にアップロードするモデルは、それぞれ以下のように学習用インスタンスに保存することにします。`/opt/ml/input/data/train/`といったパスに設定することは奇異に感じられるかもしれませんが、これらは環境変数から読み込んで使用することが可能なパスで、コーディングをシンプルにすることができます。[1-1. 環境変数の取得](#env)で読み込み方法を説明します。\n", "\n", "#### 学習データ\n", "- 画像: `/opt/ml/input/data/train/image.npy`\n", "- ラベル: `/opt/ml/input/data/train/label.npy`\n", "\n", "#### バリデーションデータ\n", "- 画像: `/opt/ml/input/data/test/image.npy`\n", "- ラベル: `/opt/ml/input/data/test/label.npy`\n", "\n", "#### モデル\n", "`/opt/ml/model` 以下にシンボルやパラメータを保存する\n", "\n", "### 書き換える箇所\n", "まず [サンプルのソースコード](https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/examples/tutorials/layers/cnn_mnist.py) を以下のコマンドでダウンロードします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget https://raw.githubusercontent.com/pytorch/examples/master/mnist/main.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ダウンロードされた `mnist.py` をファイルブラウザから見つけて開いて下さい (JupyterLab の場合は左右にファイルを並べると作業しやすいです)。あるいはお好きなエディターをお使い頂いても結構です。この`mnist.py`は、`def main()`のなかでトレーニングスクリプト内で以下の関数を呼び出し、S3以外からデータをダウンロードしています。\n", "\n", "```python\n", "dataset1 = datasets.MNIST('../data', train=True, download=True,\n", " transform=transform)\n", "dataset2 = datasets.MNIST('../data', train=False,\n", " transform=transform)\n", "```\n", "\n", "こういった方法も可能ですが、今回はS3から学習データをダウンロードして、前述したように`/opt/ml/input/data/train/`といったパスから読み出して使います。書き換える点は主に4点です:\n", "\n", "1. 環境変数の取得 \n", " SageMaker では、あらかじめ指定されたディレクトリにS3からデータがダウンロードされたり、作成したモデルを保存したりします。これらのパスを環境変数から読み込んで使用することが可能です。環境変数を読み込むことで、学習データの位置をトレーニングスクリプト内にハードコーディングする必要がありません。もちろんパスの変更は可能で、API経由で渡すこともできます。\n", " \n", "1. 引数の修正 \n", " SageMaker では学習を実行する API に hyperparameters という辞書形式の情報を渡すことができます。この情報はトレーニングスクリプトに対する引数として利用できます。例えば、\n", " ```\n", " hyperparameters = {'epoch': 100}\n", " ```\n", " と指定して `main.py` を学習する場合は、`python main.py --epoch 100` を実行することとほぼ等価です。ただし、辞書形式で表せない引数はそのままでは扱えないため、扱えるよう修正する必要があります。 \n", "1. 学習データのロード \n", " 環境変数を取得して学習データの保存先がわかれば、その保存先から学習データをロードするようにコードを書き換えましょう。\n", "\n", "1. 学習済みモデルの保存形式と出力先の変更 \n", " SageMaker は [PyTorch 用のモデルサーバ](https://github.com/aws/sagemaker-pytorch-inference-toolkit)の仕組みを利用してモデルをホストし、`.pth` または `.pt` の形式の PyTorch モデルを利用することができます。学習して得られたモデルは、正しい保存先に保存する必要があります。学習が完了すると学習用インスタンスは削除されますので、保存先を指定のディレクトリに変更して、モデルがS3にアップロードされるようにします。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-1. 環境変数の取得\n", "\n", "Amazon SageMaker で学習を行う際、学習に利用する Python スクリプト (今回の場合は PyTorch のスクリプト) を、ノートブックインスタンスとは異なる学習用インスタンスで実行します。その際、データ・モデルの入出力のパスは、 [こちら](https://sagemaker.readthedocs.io/en/stable/using_tf.html#preparing-a-script-mode-training-script) に記述されているように `SM_CHANNEL_XXXX` や `SM_MODEL_DIR` という環境変数を参照して知ることができます。\n", "\n", "![データのやりとり](../img/sagemaker-data-model.png)\n", "\n", "ここでは、学習データのパス `SM_CHANNEL_TRAIN`, テストデータのパス `SM_CHANNEL_TEST`, モデルの保存先のパス `SM_MODEL_DIR` の環境変数の値を取得します。`def main():`の直下に、環境変数を取得する以下のコードを追加します。\n", "\n", "```python\n", "def main():\n", " import os\n", " train_dir = os.environ['SM_CHANNEL_TRAIN']\n", " test_dir = os.environ['SM_CHANNEL_TEST']\n", " model_dir = os.environ['SM_MODEL_DIR']\n", "```\n", "\n", "これで学習データ・バリデーションデータ・モデルの保存先を取得することができました。次にこれらのファイルを実際に読み込む処理を実装します。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-2. 引数の修正\n", "\n", "辞書形式で表せない引数はSageMaker の学習実行時にわたすことはできません。例えば、`python main.py --save-model` とすると `save_model` が True として解釈されるような引数は辞書で表すことができません。そこで文字列 'True' や 'False' として渡して、トレーニングスクリプト内で Boolean 値の True/False に変換する必要があります。例えば、修正後は、hyperparameters は以下のように渡します。\n", "```python\n", "hyperparameters = {'save-model': 'True'}\n", "```\n", "\n", "この変更に伴って、引数を受け取るトレーニングスクリプトも修正が必要です。具体的には、Boolean 値を受け取るコードは\n", "\n", "```python\n", "parser.add_argument('--no-cuda', action='store_true', default=False,\n", " help='disables CUDA training')\n", "```\n", " \n", "のように `action='store_true'` が入っていますので、ここを修正します。修正は `action='store_true'` を `type=strtobool` として、ライブラリの `strtobool` で文字列から Boolean 値に変換します。\n", "\n", "```python\n", "parser.add_argument('--no-cuda', type=strtobool, default=False,\n", " help='disables CUDA training')\n", "```\n", "\n", "**main() の最初で `from distutils.util import strtobool` をするのを忘れないようにしましょう。**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-3. 学習データのロード\n", "\n", "元のコードでは `datasets.MNIST` を利用してダウンロード・読み込みを行っています。具体的には、`main(unused_argv)`のなかにある以下の6行です。今回はS3からデータをダウンロードするため、これらのコードは不要です。**ここで削除しましょう**。\n", "```python\n", " transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", " ])\n", " dataset1 = datasets.MNIST('../data', train=True, download=True,\n", " transform=transform)\n", " dataset2 = datasets.MNIST('../data', train=False,\n", " transform=transform)\n", " train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)\n", " test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n", "```\n", "\n", "代わりにS3からダウンロードしたデータを読み込みコードを実装しましょう。環境変数から取得した `train_dir`や`test_dir` にデータを保存したディレクトリへのパスが保存され、それぞれ `/opt/ml/input/data/train`, `/opt/ml/input/data/test` となります。詳細は [ドキュメント](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-trainingdata) をご覧下さい。デフォルトの FILE Mode では、トレーニングコンテナ起動時に S3 からこれらのディレクトリへデータがコピーされ、PIPE モードを指定すると非同期にファイルがコピーされます。\n", "\n", "今回は npy のファイルを読むようにコードを書き換えれば良いので、以下のようなコードを追記します。パスが `train_dir`, `test_dir` に保存されていることをうまく利用しましょう。もとの npy のデータタイプは uint8 ですが、画像の値を 0 から 1 の範囲内になるようにします。\n", "```python\n", "import numpy as np\n", "train_image = torch.from_numpy(np.load(os.path.join(train_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "train_image = torch.unsqueeze(train_image, 1)\n", "train_label = torch.from_numpy(np.load(os.path.join(train_dir, 'label.npy'), allow_pickle=True).astype(np.long))\n", "train_dataset = torch.utils.data.TensorDataset(train_image, train_label)\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)\n", "\n", "test_image = torch.from_numpy(np.load(os.path.join(test_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", "test_image = torch.unsqueeze(test_image, 1)\n", "test_label = torch.from_numpy(np.load(os.path.join(test_dir, 'label.npy'), allow_pickle=True).astype(np.long))\n", "test_dataset = torch.utils.data.TensorDataset(test_image, test_label)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)\n", "```\n", "\n", "#### 確認\n", "\n", "ここまでの修正で `main()` の冒頭の実装が以下の様になっていることを確認しましょう。\n", "\n", "```python\n", "def main():\n", " import os\n", " from distutils.util import strtobool\n", " train_dir = os.environ['SM_CHANNEL_TRAIN']\n", " test_dir = os.environ['SM_CHANNEL_TEST']\n", " model_dir = os.environ['SM_MODEL_DIR']\n", " # Training settings\n", " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", " parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", " help='input batch size for training (default: 64)')\n", " parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", " help='input batch size for testing (default: 1000)')\n", " parser.add_argument('--epochs', type=int, default=14, metavar='N',\n", " help='number of epochs to train (default: 14)')\n", " parser.add_argument('--lr', type=float, default=1.0, metavar='LR',\n", " help='learning rate (default: 1.0)')\n", " parser.add_argument('--gamma', type=float, default=0.7, metavar='M',\n", " help='Learning rate step gamma (default: 0.7)')\n", " parser.add_argument('--no-cuda', type=strtobool, default=False,\n", " help='disables CUDA training')\n", " parser.add_argument('--dry-run', type=strtobool, default=False,\n", " help='quickly check a single pass')\n", " parser.add_argument('--seed', type=int, default=1, metavar='S',\n", " help='random seed (default: 1)')\n", " parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n", " help='how many batches to wait before logging training status')\n", " parser.add_argument('--save-model', type=strtobool, default=False,\n", " help='For Saving the current Model')\n", " args = parser.parse_args()\n", " use_cuda = not args.no_cuda and torch.cuda.is_available()\n", "\n", " torch.manual_seed(args.seed)\n", "\n", " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", "\n", " train_kwargs = {'batch_size': args.batch_size}\n", " test_kwargs = {'batch_size': args.test_batch_size}\n", " \n", " if use_cuda:\n", " cuda_kwargs = {'num_workers': 1,\n", " 'pin_memory': True,\n", " 'shuffle': True}\n", " train_kwargs.update(cuda_kwargs)\n", " test_kwargs.update(cuda_kwargs)\n", "\n", " import numpy as np\n", " train_image = torch.from_numpy(np.load(os.path.join(train_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", " train_image = torch.unsqueeze(train_image, 1)\n", " train_label = torch.from_numpy(np.load(os.path.join(train_dir, 'label.npy'), allow_pickle=True).astype(np.long))\n", " train_dataset = torch.utils.data.TensorDataset(train_image, train_label)\n", " train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)\n", " \n", " test_image = torch.from_numpy(np.load(os.path.join(test_dir, 'image.npy'), allow_pickle=True).astype(np.float32))/255\n", " test_image = torch.unsqueeze(test_image, 1)\n", " test_label = torch.from_numpy(np.load(os.path.join(test_dir, 'label.npy'), allow_pickle=True).astype(np.long))\n", " test_dataset = torch.utils.data.TensorDataset(test_image, test_label)\n", " test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1-3. 学習済みモデルの出力先の変更\n", "\n", "学習が完了するとインスタンスが削除されてしまいますが、`/opt/ml/model` にあるファイルは model.tar.gz に圧縮され S3 に保存されます。ここに、モデル `mnist_cnn.pt` を保存して学習を終了します。パス `/opt/ml/model` は環境変数から読み込んで、変数 `model_dir` に保存しているので、それを使って保存先を指定します。\n", "\n", "\n", "以下のモデル保存のコードを\n", "```python\n", " if args.save_model:\n", " torch.save(model.state_dict(), \"mnist_cnn.pt\")\n", "```\n", "\n", "以下のように書き換えます。\n", "```python\n", " if args.save_model:\n", " torch.save(model.state_dict(), os.path.join(model_dir,\"mnist_cnn.pt\"))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Notebook 上でのデータ準備\n", "\n", "トレーニングスクリプトの書き換えは終了しました。 学習を始める前に、予め Amazon S3 にデータを準備しておく必要があります。この Notebook を使ってその作業をします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import boto3\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "\n", "sagemaker_session = sagemaker.Session()\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "機械学習に利用する手書き数字データセットの MNIST を利用します。`keras.datasets`を利用してデータセットをダウンロードし、それぞれ npy 形式で保存します。dataset のテストデータ `(X_test, y_test)` はさらにバリデーションデータとテストデータに分割します。学習データ `X_train, y_train` とバリデーションデータ `X_valid, y_valid` のみを学習に利用するため、これらを npy 形式でまずは保存します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, json\n", "NOTEBOOK_METADATA_FILE = \"/opt/ml/metadata/resource-metadata.json\"\n", "if os.path.exists(NOTEBOOK_METADATA_FILE):\n", " with open(NOTEBOOK_METADATA_FILE, \"rb\") as f:\n", " metadata = json.loads(f.read())\n", " domain_id = metadata.get(\"DomainId\")\n", " on_studio = True if domain_id is not None else False\n", "print(\"Is this notebook runnning on Studio?: {}\".format(on_studio))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python -m pip install -U scikit-image\n", "!aws s3 cp s3://fast-ai-imageclas/mnist_png.tgz . --no-sign-request\n", "if on_studio:\n", " !tar -xzf mnist_png.tgz -C /opt/ml --no-same-owner\n", "else:\n", " !tar -xvzf mnist_png.tgz" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage.io import ImageCollection,concatenate_images\n", "from PIL import Image\n", "import numpy as np\n", "import pathlib\n", "\n", "def load_image_with_label(f):\n", " label = pathlib.PurePath(f).parent.name\n", " return np.array(Image.open(f)), label\n", "if on_studio:\n", " dataset = ImageCollection(\"/opt/ml/mnist_png/*/*/*.png\", load_func=load_image_with_label)\n", "else:\n", " dataset = ImageCollection(\"./mnist_png/*/*/*.png\", load_func=load_image_with_label)\n", "np_dataset = np.array(dataset, dtype=\"object\")\n", "X = concatenate_images(np_dataset[:,0])\n", "y = np_dataset[:,1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "index = np.random.permutation(70000)\n", "X = X[index]\n", "y = y[index]\n", "\n", "X_train = X[0:50000,0:784]\n", "X_valid = X[50000:60000,0:784]\n", "X_test = X[60000:70000,0:784]\n", "y_train = y[0:50000]\n", "y_valid = y[50000:60000]\n", "y_test = y[60000:70000]\n", "\n", "os.makedirs('data/train', exist_ok=True)\n", "os.makedirs('data/valid', exist_ok=True)\n", "np.save('data/train/image.npy', X_train)\n", "np.save('data/train/label.npy', y_train)\n", "np.save('data/valid/image.npy', X_test)\n", "np.save('data/valid/label.npy', y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これを Amazon S3 にアップロードします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data = sagemaker_session.upload_data(path='data/train', key_prefix='data/mnist-npy/train')\n", "valid_data = sagemaker_session.upload_data(path='data/valid', key_prefix='data/mnist-npy/valid')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. トレーニングの実行\n", "\n", "`from sagemaker.pytorch import PyTorch` で読み込んだ SageMaker Python SDK の PyTorch Estimator を作ります。\n", "\n", "ここでは、学習に利用するインスタンス数 `instance_count` や インスタンスタイプ `instance_type` を指定します。\n", "Docker を実行可能な環境であれば、`instance_type = \"local\"` と指定すると、追加のインスタンスを起動することなく、いま、このノートブックを実行している環境でトレーニングを実行できます。インスタンス起動を待つ必要がないためデバッグに便利です。\n", "\n", "hyperparameters で指定した内容をトレーニングスクリプトに引数として渡すことができますので、`hyperparameters = {\"epoch\": 3}` として 3 エポックだけ実行してみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", "\n", "instance_type = \"ml.m4.xlarge\"\n", "\n", "mnist_estimator = PyTorch(entry_point='main.py',\n", " role=role,\n", " instance_count=1,\n", " instance_type=instance_type,\n", " framework_version='1.8.1',\n", " py_version='py3',\n", " hyperparameters = {\"epoch\": 3, \n", " \"save-model\": \"True\"})\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`estimator.fit` によりトレーニングを開始しますが、ここで指定する「チャネル」によって、環境変数名 `SM_CHANNEL_XXXX` が決定されます。この例の場合、`'train', 'test'` を指定しているので、`SM_CHANNEL_TRAIN`, `SM_CHANNEL_TEST` となります。トレーニングスクリプトで環境変数を参照している場合は、fit 内の指定と一致していることを確認します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "mnist_estimator.fit({'train': train_data, 'test': valid_data})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`main.py` の中で書き換えに間違いがあったら、ここでエラーとなる場合があります。\n", "\n", " `===== Job Complete =====`\n", "と表示されれば成功です。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 学習済みモデルの確認\n", "\n", "Amazon S3 に保存されたモデルは普通にダウンロードして使うこともできます。保存先は `estimator.model_data` で確認できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist_estimator.model_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 推論スクリプトの作成" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "作成したモデルは SageMaker でホスティングすることができます。そうすると、クライアントから推論リクエストを受け取って、推論結果を返すことが可能になります。\n", "\n", "ホスティングする際には、(1) 作成したモデルを読み込んで、(2)推論を実行するスクリプトが必要で、それぞれ `model_fn` と `transform_fn` という関数で実装します。それ以外の関数の実装は不要です。\n", "\n", "1. model_fn(model_dir) \n", " `modle_dir` に学習したモデルが展開されている状態で `model_fn` が実行されます。通常、モデルを読み込んで、return するコードのみを実装します。PyTorch はモデルのパラメータのみを保存して利用するのが一般的で、シンボル・グラフの内容は推論コード内で定義する必要があります。\n", "\n", "```python \n", "from io import BytesIO\n", "import json\n", "import numpy as np\n", "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", " self.dropout1 = nn.Dropout(0.25)\n", " self.dropout2 = nn.Dropout(0.5)\n", " self.fc1 = nn.Linear(9216, 128)\n", " self.fc2 = nn.Linear(128, 10)\n", "\n", " def forward(self, x):\n", " x = self.conv1(x)\n", " x = F.relu(x)\n", " x = self.conv2(x)\n", " x = F.relu(x)\n", " x = F.max_pool2d(x, 2)\n", " x = self.dropout1(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " x = self.dropout2(x)\n", " x = self.fc2(x)\n", " output = F.log_softmax(x, dim=1)\n", " return output\n", "\n", "def model_fn(model_dir):\n", " model = Net()\n", " with open(os.path.join(model_dir, \"mnist_cnn.pt\"), \"rb\") as f:\n", " model.load_state_dict(torch.load(f))\n", " model.eval() # for inference\n", " return model\n", "```\n", " \n", " 複数のモデルを読み込む場合や NLP のように語彙ファイルも読み込む場合は、それらを読み込んで dict 形式などで return します。return した内容が `transform_fn(model, request_body, request_content_type, response_content_type)` の `model` に引き継がれます。\n", "\n", "2. transform_fn(model, request_body, request_content_type, response_content_type) \n", " 読み込んだ model に推論リクエスト (request_body) を渡して、推論結果を return するようなコードを書きます。例えば、推論リクエストの形式がいくつかあって、それに基づいて request_body に対する前処理を変えたい場合は、クライアントにcontent_type を指定させ、それをrequest_content_type として受け取って条件分岐で実装します。\n", " \n", " request_body は byte 形式で届きます。これをクライアントが送付した形式に合わせて読み込みます。例えば、numpy 形式で送られたものであれば、`np.load(BytesIO(request_body))`のようにして numpy 形式で読み込みます。PyTorch の場合だと、Torch Tensor の形式にして推論することが多いと思いますので、そのような実装を行って推論結果を return します。必要に応じて response_content_type で指定した形式で return すると、クライアント側で結果の使い分けができたりします。\n", " \n", " 今回は numpy で受け取って結果をjson で返すようにします。 \n", " \n", "```python\n", "def transform_fn(model, request_body, request_content_type, response_content_type):\n", " input_data = np.load(BytesIO(request_body))/255\n", " input_data = torch.from_numpy(input_data)\n", " input_data = torch.unsqueeze(input_data, 1)\n", " prediction = model(input_data)\n", " return json.dumps(prediction.tolist())\n", "```\n", " \n", "以上のコードを `deploy.py` にまとめて作成します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "mnist_model=PyTorchModel(model_data=mnist_estimator.model_data, \n", " role=role, \n", " entry_point='deploy.py', \n", " framework_version='1.8.1',\n", " py_version='py3')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor=mnist_model.deploy(instance_type='ml.m4.xlarge', initial_instance_count=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "テストデータセットからランダムに10枚選んでテストを行います。PyTorch の SageMaker Predictor は numpy 形式を想定しているので、JSON 形式を受け取る場合は、`JSONDeserializer()` を指定しましょう。10枚の画像に対する結果を表示します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy.special import softmax\n", "\n", "test_size = 10\n", "select_idx = np.random.choice(np.arange(y_test.shape[0]), test_size)\n", "test_sample = X_test[select_idx].reshape([test_size,28,28]).astype(np.float32)\n", "\n", "predictor.deserializer=sagemaker.deserializers.JSONDeserializer()\n", "result = predictor.predict(test_sample)\n", "\n", "result = softmax(np.array(result), axis=1)\n", "predict_class = np.argmax(result, axis=1)\n", "print(\"Predicted labels: {}\".format(predict_class))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 画像の確認\n", "実際の画像を確認してみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "W = 10 # 横に並べる個数\n", "H = 10 # 縦に並べる個数\n", "fig = plt.figure(figsize=(H, W))\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1.0, hspace=0.05, wspace=0.05)\n", "for i in range(test_size):\n", " ax = fig.add_subplot(H, W, i + 1, xticks=[], yticks=[])\n", " ax.set_title(\"{} ({:.3f})\".format(predict_class[i], result[i][predict_class[i]]), color=\"green\")\n", " ax.imshow(test_sample[i].reshape((28, 28)), cmap='gray')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "推論エンドポイントは立てっぱなしにしているとお金がかかるので、確認が終わったら忘れないうちに削除してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. まとめ" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch を使った Amazon SageMaker への移行手順について紹介しました。普段お使いのモデルでも同様の手順で移行が可能ですのでぜひ試してみてください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (PyTorch 1.6 Python 3.6 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.6-cpu-py36-ubuntu16.04-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }