{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Deploying a pretrained PyTorch DeepLab model on SageMaker\n",
  "\n",
  "This notebook downloads the [DeepLab V3 model](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/) published on PyTorch Hub and deploys it. To show how to bring your own model into SageMaker, the deployment proceeds in the following steps:\n",
  "\n",
  "1. Download the model from PyTorch Hub and save it to S3.\n",
  "1. Write the inference code that runs predictions with the downloaded model.\n",
  "1. Deploy to SageMaker, pointing at the model saved in S3.\n",
  "\n",
  "In practice, the inference code can download the model from PyTorch Hub by itself, so it is also possible to skip step 1 (a sketch of that variant is shown after the inference code in section 2).\n",
  "\n",
  "## 1. Downloading the model from PyTorch Hub\n",
  "\n",
  "Download the model with `torch.hub` and save only its parameters (the `state_dict`). The saved file is packed into a `tar.gz` archive and uploaded to S3."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import os\n",
  "import torch\n",
  "\n",
  "# Download the pretrained DeepLab V3 (ResNet-101 backbone) model from PyTorch Hub\n",
  "os.makedirs('model', exist_ok=True)\n",
  "model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True, progress=False)\n",
  "\n",
  "# Save only the parameters; the network definition is rebuilt from PyTorch Hub at inference time\n",
  "path = './model/model.pth'\n",
  "torch.save(model.cpu().state_dict(), path)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "!tar cvzf model.tar.gz -C ./model ."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import sagemaker\n",
  "sagemaker_session = sagemaker.Session()\n",
  "model_path = sagemaker_session.upload_data(\"model.tar.gz\", key_prefix=\"pytorch_deeplab_model\")"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "model_path"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## 2. Writing the inference code\n",
  "\n",
  "Next, write the code that loads the uploaded model and runs inference. Loading the model is implemented in `model_fn`, and running inference in `transform_fn`.\n",
  "In PyTorch the network definition is needed in addition to the parameters, so it is obtained from PyTorch Hub. The implementation of each function follows the [official usage example](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/)."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile deploy.py\n",
  "\n",
  "from io import BytesIO\n",
  "import json\n",
  "import os\n",
  "\n",
  "import numpy as np\n",
  "import torch\n",
  "from torchvision import transforms\n",
  "\n",
  "def model_fn(model_dir):\n",
  "    # Rebuild the network definition from PyTorch Hub, then load the saved parameters\n",
  "    model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=False, progress=False)\n",
  "    with open(os.path.join(model_dir, \"model.pth\"), \"rb\") as f:\n",
  "        # strict=False so that extra keys in the checkpoint (such as the aux_classifier\n",
  "        # weights used during training) are ignored\n",
  "        model.load_state_dict(torch.load(f), strict=False)\n",
  "    model.eval()  # for inference\n",
  "    return model\n",
  "\n",
  "def transform_fn(model, request_body, request_content_type, response_content_type):\n",
  "\n",
  "    # The request body is a serialized NumPy array (H x W x 3, uint8)\n",
  "    input_data = np.load(BytesIO(request_body))\n",
  "    preprocess = transforms.Compose([\n",
  "        transforms.ToTensor(),\n",
  "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
  "    ])\n",
  "\n",
  "    input_tensor = preprocess(input_data)\n",
  "    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension\n",
  "    with torch.no_grad():\n",
  "        prediction = model(input_batch)\n",
  "    return json.dumps(prediction['out'].tolist())"
 ] },
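 { "cell_type": "markdown", "metadata": {}, "source": [
  "As noted at the top of the notebook, step 1 can be skipped by letting `model_fn` download the pretrained weights from PyTorch Hub instead of reading `model.pth` from `model_dir`. The next cell is a minimal sketch of that variant: only `model_fn` is shown (a `transform_fn` like the one in `deploy.py` would still be needed), the file name `deploy_hub.py` is purely illustrative, and the endpoint must have outbound internet access to fetch the weights at startup."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile deploy_hub.py\n",
  "\n",
  "# Sketch of an alternative entry point: model_fn pulls the pretrained weights\n",
  "# directly from PyTorch Hub, so no model.tar.gz has to be prepared in step 1.\n",
  "import torch\n",
  "\n",
  "def model_fn(model_dir):\n",
  "    # pretrained=True downloads the weights; nothing is read from model_dir\n",
  "    model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True, progress=False)\n",
  "    model.eval()\n",
  "    return model"
 ] },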
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## 3. Deployment and inference"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "from sagemaker.pytorch.model import PyTorchModel\n",
  "\n",
  "deeplab_model = PyTorchModel(model_data=model_path,\n",
  "                             role=sagemaker.get_execution_role(),\n",
  "                             entry_point='deploy.py',\n",
  "                             framework_version='1.8.1',\n",
  "                             py_version='py3')"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "predictor = deeplab_model.deploy(instance_type='ml.m4.xlarge', initial_instance_count=1)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "!wget https://github.com/pytorch/hub/raw/master/images/dog.jpg"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import torch\n",
  "from PIL import Image\n",
  "\n",
  "# Load the sample image and shrink it to keep the request payload small\n",
  "input_image = Image.open(\"dog.jpg\").convert('RGB')\n",
  "input_image = input_image.resize((150, 100))\n",
  "np_input_image = np.array(input_image)\n",
  "\n",
  "# The endpoint returns JSON, so use the JSON deserializer\n",
  "predictor.deserializer = sagemaker.deserializers.JSONDeserializer()\n",
  "predictions = predictor.predict(np_input_image)\n",
  "\n",
  "# create a color palette, selecting a color for each class\n",
  "palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])\n",
  "colors = torch.as_tensor([i for i in range(21)])[:, None] * palette\n",
  "colors = (colors % 255).numpy().astype(\"uint8\")\n",
  "label_map = np.array(predictions[0]).argmax(0)\n",
  "\n",
  "# plot the semantic segmentation predictions of 21 classes in each color\n",
  "r = Image.fromarray(label_map.astype(np.uint8))\n",
  "r.putpalette(colors)\n",
  "r = Image.blend(r.convert('RGBA'), input_image.convert('RGBA'), 0.5)\n",
  "\n",
  "plt.rcParams['figure.figsize'] = [12, 8]\n",
  "plt.imshow(r)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Finally, delete the endpoint, which is no longer needed."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "predictor.delete_endpoint()"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }
], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (PyTorch 1.6 Python 3.6 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.6-cpu-py36-ubuntu16.04-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }