{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Amazon SageMaker で Detectron2 と SKU-110K データセットを使って物体検出\n", "\n", "このノートブックでは、Amazon SageMaker で Detectron2 の物体検出モデルを fine-tuning する方法をご紹介します。このノートブックについては、[こちらの AWS blog](https://aws.amazon.com/jp/blogs/news/object-detection-with-detectron2-on-amazon-sagemaker/) で解説しています。\n", "\n", "ml.p3.8xlarge などハイスペックのインスタンスを使用するので、料金にご注意ください。インスタンスごとの料金は [こちらのサイト](https://aws.amazon.com/jp/sagemaker/pricing/) で確認できます。" ] }, { "cell_type": "markdown", "metadata": { "toc": true }, "source": [ "
<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", "<div class=\"toc\"><ul class=\"toc-item\"></ul></div>
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 背景\n", "\n", "Detectron2 は、物体検出アルゴリズムを実装したコンピュータビジョンフレームワークです。Facebook の AI リサーチチームによって開発されました。Detecton2 の祖先である Detectron は完全に Caffe で記述されていましたが、Detecton2 は PyTorch でリファクタリングされており、高速な実験とイテレーションを可能にします。Detectron2 は、物体検出、セマンティックセグメンテーション、ポーズ推定などの最先端のモデルを含む豊富な model zoo を備えています。モジュール式の設計により、Detetron2は容易に拡張可能であり、その結果、最先端の研究プロジェクトをその上に実装することができます。\n", "\n", "このノートブックでは、Detectron2 を用いて [SKU110k-dataset](https://github.com/eg4000/SKU110K_CVPR19) のモデルを学習・評価します。このオープンソースのデータセットには、小売店の棚の画像が含まれています。各画像には約150個のオブジェクトが含まれており、密集したシーンの物体検出アルゴリズムのテストに適しています。バウンディングボックスは、製品のカテゴリを区別することなく、SKU に関連付けられています。\n", "\n", "このノートブックでは、Detectron2 の model zoo から物体検出モデルを使用しています。そして、Amazon SageMaker MLプラットフォームを利用して、SKU110kデータセットで事前に学習されたモデルを fine-tuning し、学習されたモデルを推論用にデプロイします。\n", "\n", "**注意:このノートブックを Amazon SageMaker ノートブックインスタンスで使用する場合は、インスタンスにアタッチする EBS ボリュームサイズを 80GB くらいに設定するのがおすすめです。50GB 程度ですと、コンテナイメージをビルドする際に容量不足になることがあります。**\n", "\n", "ノートブックインスタンスを使用している場合、コンテナイメージビルドの際の容量不足を回避するために以下のセルのコメントアウトを外してから実行して docker 関連のファイルの保存場所を変更してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %%bash\n", "\n", "# sudo /etc/init.d/docker stop\n", "# sudo mv /var/lib/docker /home/ec2-user/SageMaker/docker\n", "# sudo ln -s /home/ec2-user/SageMaker/docker /var/lib/docker\n", "# sudo /etc/init.d/docker start" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## セットアップ\n", "\n", "**注意:Sagemaker Notebook インスタンスまたは Sagemaker Studio インスタンスを使用してこのノートブックを実行する場合は、`AmazonSageMakerFullAccess` と `AmazonEC2ContainerRegistryFullAccess` ポリシーが付与された IAMロールを使用していることを確認してください。足りないポリシーがあれば、以下の手順で追加してください。**\n", "\n", "1. [Amazon SageMaker console](https://console.aws.amazon.com/sagemaker/) を開く\n", "1. **ノートブックインスタンス** を開いて現在使用しているノートブックインスタンスを選択する\n", "1. **アクセス許可と暗号化** の部分に表示されている IAM ロールへのリンクをクリックする\n", "1. IAM ロールの ARN は後で使用するのでメモ帳などにコピーしておく\n", "1. **ポリシーをアタッチします** をクリックして `AmazonEC2ContainerRegistryFullAccess` を検索する\n", "1. `AmazonEC2ContainerRegistryFullAccess` の横のチェックボックスをオンにする\n", "1. 
必要なポリシーの数だけ、同様の手順でポリシーを検索しチェックボックスをオンにして **ポリシーのアタッチ** をクリックする\n", "\n", "まず、必要なPythonライブラリをインポートし、いくつかの共通パラメタを設定します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "import sagemaker\n", "\n", "assert (\n", " sagemaker.__version__.split(\".\")[0] == \"2\"\n", "), \"Please upgrade SageMaker Python SDK to version 2\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bucket = sagemaker.session.Session().default_bucket()\n", "prefix_data = \"detectron2/data\"\n", "prefix_model = \"detectron2/training_artefacts\"\n", "prefix_code = \"detectron2/model\"\n", "prefix_predictions = \"detectron2/predictions\"\n", "local_folder = \"cache\" # cache folder used to store downloaded data - not versioned\n", "\n", "\n", "sm_session = sagemaker.Session(default_bucket=bucket)\n", "role = sagemaker.get_execution_role()\n", "region = sm_session.boto_region_name\n", "account = sm_session.account_id()\n", "\n", "# if bucket doesn't exist, create one\n", "s3_resource = boto3.resource(\"s3\")\n", "if not s3_resource.Bucket(bucket) in s3_resource.buckets.all():\n", " s3_resource.create_bucket(\n", " Bucket=bucket, CreateBucketConfiguration={\"LocationConstraint\": region}\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## データセットの準備\n", "\n", "SKU110K を学習用に用意するには、以下の作業が必要です。\n", "\n", "- SKU-110K のデータセットをダウンロードし、解凍する。\n", "- ファイル名の prefix に従って、画像を3つのチャンネル(学習、検証、テスト)に分割する。\n", "- PIL.Image.load()で読み込めないような、破損した画像(および対応するアノテーション)を削除する。\n", "- 画像チャンネルをS3バケットにアップロードする。\n", "- アノテーションデータを拡張マニフェストファイル形式に変換し、これらのファイルをS3にアップロードする。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "import tarfile\n", "import tempfile\n", "from datetime import datetime\n", "from pathlib import Path\n", "from typing import Mapping, Optional, Sequence\n", "from urllib import request\n", "\n", "import boto3\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SKU-110K データセットのダウンロード\n", "\n", "解凍したデータセットの合計サイズは 12.2GBです。これに合わせて、ノートブックインスタンスのボリュームサイズを設定してください。30GB程度のボリュームサイズをおすすめします。\n", "\n", "⚠️ データセットのダウンロードと解凍には15〜20分程度かかります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! 
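# Download the SKU-110K archive into the local cache folder (12.2 GB uncompressed)\n", "! 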
wget -P cache http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sku_dataset_dirname = \"SKU110K_fixed\"\n", "assert Path(\n", " local_folder\n", ").exists(), f\"Set wget directory-prefix to {local_folder} in the previous cell\"\n", "\n", "\n", "def track_progress(members):\n", " i = 0\n", " for member in members:\n", " if i % 100 == 0:\n", " print(\".\", end=\"\")\n", " i += 1\n", " yield member\n", "\n", "\n", "if not (Path(local_folder) / sku_dataset_dirname).exists():\n", " compressed_file = tarfile.open(\n", " name=os.path.join(local_folder, sku_dataset_dirname + \".tar.gz\")\n", " )\n", " compressed_file.extractall(\n", " path=local_folder, members=track_progress(compressed_file)\n", " )\n", "else:\n", " print(f\"Using the data in `{local_folder}` folder\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 画像の前処理" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path_images = Path(local_folder) / sku_dataset_dirname / \"images\"\n", "assert path_images.exists(), f\"{path_images} not found\"\n", "\n", "prefix_to_channel = {\n", " \"train\": \"training\",\n", " \"val\": \"validation\",\n", " \"test\": \"test\",\n", "}\n", "for channel_name in prefix_to_channel.values():\n", " if not (path_images.parent / channel_name).exists():\n", " (path_images.parent / channel_name).mkdir()\n", "\n", "for path_img in path_images.iterdir():\n", " for prefix in prefix_to_channel:\n", " if path_img.name.startswith(prefix):\n", " path_img.replace(\n", " path_images.parent / prefix_to_channel[prefix] / path_img.name\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Detectron2 は Pillow ライブラリを使って画像を読み込んでいます。SKU データセットに含まれる一部の画像が破損しており、データローダが IOError 例外を発生させることがわかりました。そこで、それらの画像をデータセットから削除します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CORRUPTED_IMAGES = {\n", " \"training\": (\"train_4222.jpg\", \"train_5822.jpg\", \"train_882.jpg\", \"train_924.jpg\"),\n", " \"validation\": tuple(),\n", " \"test\": (\"test_274.jpg\", \"test_2924.jpg\"),\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for channel_name in prefix_to_channel.values():\n", " for img_name in CORRUPTED_IMAGES[channel_name]:\n", " try:\n", " (path_images.parent / channel_name / img_name).unlink()\n", " print(f\"{img_name} removed from channel {channel_name} \")\n", " except FileNotFoundError:\n", " print(f\"{img_name} not in channel {channel_name}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for channel_name in prefix_to_channel.values():\n", " print(\n", " f\"Number of {channel_name} images = {sum(1 for x in (path_images.parent / channel_name).glob('*.jpg'))}\"\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "データセットを S3 にアップロードします。 ⚠️ この処理には 10-15 分程度かかります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "channel_to_s3_imgs = {}\n", "\n", "for channel_name in prefix_to_channel.values():\n", " inputs = sm_session.upload_data(\n", " path=str(path_images.parent / channel_name),\n", " bucket=bucket,\n", " key_prefix=f\"{prefix_data}/{channel_name}\",\n", " )\n", " print(f\"{channel_name} images uploaded to {inputs}\")\n", " channel_to_s3_imgs[channel_name] = inputs" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "### アノテーションデータの前処理\n", "\n", "SKU-110K データセットのアノテーションは csv ファイルで保存されています。ここでは、それらを [拡張マニフェストファイル](https://docs.aws.amazon.com/sagemaker/latest/dg/augmented-manifest.html) に変換しています。バウンディングボックスアノテーションの仕様については、[SageMakerのドキュメント](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-data-output.html#sms-output-box) をご参照ください。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_annotation_channel(\n", " channel_id: str,\n", " path_to_annotation: Path,\n", " bucket_name: str,\n", " data_prefix: str,\n", " img_annotation_to_ignore: Optional[Sequence[str]] = None,\n", ") -> Sequence[Mapping]:\n", " r\"\"\"Change format from original to augmented manifest files\n", "\n", " Parameters\n", " ----------\n", " channel_id : str\n", " name of the channel, i.e. training, validation or test\n", " path_to_annotation : Path\n", " path to annotation file\n", " bucket_name : str\n", " bucket where the data are uploaded\n", " data_prefix : str\n", " bucket prefix\n", " img_annotation_to_ignore : Optional[Sequence[str]]\n", " annotation from these images are ignore because the corresponding images are corrupted, default to None\n", "\n", " Returns\n", " -------\n", " Sequence[Mapping]\n", " List of json lines, each lines contains the annotations for a single. This recreates the\n", " format of augmented manifest files that are generated by Amazon SageMaker GroundTruth\n", " labeling jobs\n", " \"\"\"\n", " if channel_id not in (\"training\", \"validation\", \"test\"):\n", " raise ValueError(\n", " f\"Channel identifier must be training, validation or test. The passed values is {channel_id}\"\n", " )\n", " if not path_to_annotation.exists():\n", " raise FileNotFoundError(f\"Annotation file {path_to_annotation} not found\")\n", "\n", " df_annotation = pd.read_csv(\n", " path_to_annotation,\n", " header=0,\n", " names=(\n", " \"image_name\",\n", " \"x1\",\n", " \"y1\",\n", " \"x2\",\n", " \"y2\",\n", " \"class\",\n", " \"image_width\",\n", " \"image_height\",\n", " ),\n", " )\n", "\n", " df_annotation[\"left\"] = df_annotation[\"x1\"]\n", " df_annotation[\"top\"] = df_annotation[\"y1\"]\n", " df_annotation[\"width\"] = df_annotation[\"x2\"] - df_annotation[\"x1\"]\n", " df_annotation[\"height\"] = df_annotation[\"y2\"] - df_annotation[\"y1\"]\n", " df_annotation.drop(columns=[\"x1\", \"x2\", \"y1\", \"y2\"], inplace=True)\n", "\n", " jsonlines = []\n", " for img_id in df_annotation[\"image_name\"].unique():\n", " if img_annotation_to_ignore and img_id in img_annotation_to_ignore:\n", " print(\n", " f\"Annotations for image {img_id} are neglected as the image is corrupted\"\n", " )\n", " continue\n", " img_annotations = df_annotation.loc[df_annotation[\"image_name\"] == img_id, :]\n", " annotations = []\n", " for (\n", " _,\n", " _,\n", " img_width,\n", " img_heigh,\n", " bbox_l,\n", " bbox_t,\n", " bbox_w,\n", " bbox_h,\n", " ) in img_annotations.itertuples(index=False):\n", " annotations.append(\n", " {\n", " \"class_id\": 0,\n", " \"width\": bbox_w,\n", " \"top\": bbox_t,\n", " \"left\": bbox_l,\n", " \"height\": bbox_h,\n", " }\n", " )\n", " jsonline = {\n", " \"sku\": {\n", " \"annotations\": annotations,\n", " \"image_size\": [{\"width\": img_width, \"depth\": 3, \"height\": img_heigh,}],\n", " },\n", " \"sku-metadata\": {\n", " \"job_name\": f\"labeling-job/sku-110k-{channel_id}\",\n", " \"class-map\": {\"0\": \"SKU\"},\n", " \"human-annotated\": \"yes\",\n", " \"objects\": len(annotations) * [{\"confidence\": 0.0}],\n", " \"type\": 
\"groundtruth/object-detection\",\n", " \"creation-date\": datetime.now()\n", " .replace(second=0, microsecond=0)\n", " .isoformat(),\n", " },\n", " \"source-ref\": f\"s3://{bucket_name}/{data_prefix}/{channel_id}/{img_id}\",\n", " }\n", " jsonlines.append(jsonline)\n", " return jsonlines" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "annotation_folder = Path(local_folder) / sku_dataset_dirname / \"annotations\"\n", "channel_to_annotation_path = {\n", " \"training\": annotation_folder / \"annotations_train.csv\",\n", " \"validation\": annotation_folder / \"annotations_val.csv\",\n", " \"test\": annotation_folder / \"annotations_test.csv\",\n", "}\n", "channel_to_annotation = {}\n", "\n", "for channel in channel_to_annotation_path:\n", " annotations = create_annotation_channel(\n", " channel,\n", " channel_to_annotation_path[channel],\n", " bucket,\n", " prefix_data,\n", " CORRUPTED_IMAGES[channel],\n", " )\n", " print(f\"Number of {channel} annotations: {len(annotations)}\")\n", " channel_to_annotation[channel] = annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def upload_annotations(p_annotations, p_channel: str):\n", " rsc_bucket = boto3.resource(\"s3\").Bucket(bucket)\n", "\n", " json_lines = [json.dumps(elem) for elem in p_annotations]\n", " to_write = \"\\n\".join(json_lines)\n", "\n", " with tempfile.NamedTemporaryFile(mode=\"w\") as fid:\n", " fid.write(to_write)\n", " rsc_bucket.upload_file(\n", " fid.name, f\"{prefix_data}/annotations/{p_channel}.manifest\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for channel_id, annotations in channel_to_annotation.items():\n", " upload_annotations(annotations, channel_id)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "学習セット、検証セット、テストセットに含まれる画像の数を確認し、アップロードや前処理に失敗した場合にユーザーが学習を開始する前に検出できるようにしましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "channel_to_expected_size = {\n", " \"training\": 8215,\n", " \"validation\": 588,\n", " \"test\": 2934,\n", "}\n", "\n", "prefix_data = \"detectron2/data\"\n", "bucket_rsr = boto3.resource(\"s3\").Bucket(bucket)\n", "for channel_name, exp_nb in channel_to_expected_size.items():\n", " nb_objs = len(\n", " list(bucket_rsr.objects.filter(Prefix=f\"{prefix_data}/{channel_name}\"))\n", " )\n", " assert (\n", " nb_objs == exp_nb\n", " ), f\"The {channel_name} set should have {exp_nb} images but it contains {nb_objs} images\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Amazon SageMaker を使って学習\n", "\n", "SageMaker で学習ジョブを実行するには、以下の作業を行います。\n", "\n", "- 学習コンテナを構築し、Amazon Elastic Container Registry (ECR) にプッシュする。コンテナにはすべてのランタイム依存ファイルと学習スクリプトが含まれる\n", "- 学習クラスタの構成やモデルのハイパーパラメタを含む学習ジョブの構成を定義する\n", "- 学習ジョブをスケジューリングし、その進捗を確認する\n", "\n", "\n", "### 学習用コンテナのビルド\n", "学習コンテナをビルドする前に、Pytorch のベースイメージを取得するための共有 ECR リポジトリと、プライベート ECR リポジトリで認証を行う必要があります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin 763104351884.dkr.ecr.{region}.amazonaws.com\n", "# loging to your private ECR\n", "!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これからビルドするコンテナは、AWS が用意した Pytorch 
コンテナをベースイメージとして使用します。ベースイメージに Detecton2 の依存関係を追加し、学習スクリプトをコピーします。以下のセルを実行して Dockerfile の内容を確認します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "# execute this cell to review Docker container\n", "pygmentize -l docker Dockerfile.sku110ktraining" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "次に、ローカルで Docker コンテナをビルドし、ECR リポジトリにプッシュすることで、SageMaker が学習時にこのコンテナをコンピュートノードにデプロイできるようにします。以下のコマンドを実行して、コンテナをビルドしてプッシュします。以下のセルの実行が完了するまで 5分以上かかることもあります。作成される Docker イメージのサイズは約5GBです。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "./build_and_push.sh sagemaker-d2-train-sku110k latest Dockerfile.sku110ktraining" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SageMaker 学習ジョブの設定\n", "\n", "設定項目としては以下があります。\n", "\n", "- data configuration: train/test/valのデータセットをどこに保存するかを定義します。\n", "- コンテナの設定\n", "- モデルのハイパーパラメタの設定\n", "- クラスタのサイズやインスタンスの種類、監視するメトリクスなどの学習ジョブのパラメタ" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "import boto3\n", "from sagemaker.estimator import Estimator" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Data configuration\n", "\n", "training_channel = f\"s3://{bucket}/{prefix_data}/training/\"\n", "validation_channel = f\"s3://{bucket}/{prefix_data}/validation/\"\n", "test_channel = f\"s3://{bucket}/{prefix_data}/test/\"\n", "\n", "annotation_channel = f\"s3://{bucket}/{prefix_data}/annotations/\"\n", "\n", "classes = [\n", " \"SKU\",\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Container configuration\n", "\n", "container_name = \"sagemaker-d2-train-sku110k\"\n", "container_version = \"latest\"\n", "training_image_uri = (\n", " f\"{account}.dkr.ecr.{region}.amazonaws.com/{container_name}:{container_version}\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Metrics to monitor during training, each metric is scraped from container Stdout\n", "\n", "metrics = [\n", " {\"Name\": \"training:loss\", \"Regex\": \"total_loss: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_cls\", \"Regex\": \"loss_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_box_reg\", \"Regex\": \"loss_box_reg: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_rpn_cls\", \"Regex\": \"loss_rpn_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_rpn_loc\", \"Regex\": \"loss_rpn_loc: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss\", \"Regex\": \"total_val_loss: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_cls\", \"Regex\": \"val_loss_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_box_reg\", \"Regex\": \"val_loss_box_reg: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_rpn_cls\", \"Regex\": \"val_loss_rpn_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_rpn_loc\", \"Regex\": \"val_loss_rpn_loc: ([0-9\\\\.]+)\",},\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Training instance type\n", "\n", "training_instance = \"ml.p3.8xlarge\"\n", "if training_instance.startswith(\"local\"):\n", " training_session = sagemaker.LocalSession()\n", " training_session.config = {\"local\": {\"local_code\": True}}\n", "else:\n", " training_session = sm_session" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"学習ジョブでは,以下のハイパーパラメタを使用しています.自由に変更して実験してみてください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Model Hyperparameters\n", "\n", "od_algorithm = \"faster_rcnn\" # choose one in (\"faster_rcnn\", \"retinanet\")\n", "training_job_hp = {\n", " # Dataset\n", " \"classes\": json.dumps(classes),\n", " \"dataset-name\": json.dumps(\"sku110k\"),\n", " \"label-name\": json.dumps(\"sku\"),\n", " # Algo specs\n", " \"model-type\": json.dumps(od_algorithm),\n", " \"backbone\": json.dumps(\"R_101_FPN\"),\n", " # Data loader\n", " \"num-iter\": 900,\n", " \"log-period\": 500,\n", " \"batch-size\": 16,\n", " \"num-workers\": 8,\n", " # Optimization\n", " \"lr\": 0.005,\n", " \"lr-schedule\": 3,\n", " # Faster-RCNN specific\n", " \"num-rpn\": 517,\n", " \"bbox-head-pos-fraction\": 0.2,\n", " \"bbox-rpn-pos-fraction\": 0.4,\n", " # Prediction specific\n", " \"nms-thr\": 0.2,\n", " \"pred-thr\": 0.1,\n", " \"det-per-img\": 300,\n", " # Evaluation\n", " \"evaluation-type\": \"fast\",\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SageMaker 学習ジョブでモデルを学習してみましょう。学習状況は SageMaker コンソールのメニューで「トレーニング」→「トレーニングジョブ」でジョブ一覧画面から確認可能です。学習ジョブが完了するまで 20分程度かかります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compile Sagemaker Training job object and start training\n", "\n", "d2_estimator = Estimator(\n", " image_uri=training_image_uri,\n", " role=role,\n", " sagemaker_session=training_session,\n", " instance_count=2,\n", " instance_type=training_instance,\n", " hyperparameters=training_job_hp,\n", " metric_definitions=metrics,\n", " output_path=f\"s3://{bucket}/{prefix_model}\",\n", " base_job_name=f\"detectron2-{od_algorithm.replace('_', '-')}\",\n", ")\n", "\n", "d2_estimator.fit(\n", " {\n", " \"training\": training_channel,\n", " \"validation\": validation_channel,\n", " \"test\": test_channel,\n", " \"annotation\": annotation_channel,\n", " },\n", " wait=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Amazon SageMaker のハイパーパラメタチューニング\n", "\n", "SageMaker SDK には `tuner` モジュールが付属しており、これを使って最適なハイパーパラメタを探すことができます(詳細は [こちら](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html))。ここでは、検証データの loss を最小化することを目的として、異なるハイパーパラメタを用いていくつかの実験を行ってみましょう。\n", "\n", "最適化されるハイパーパラメタを定義する `hparams_range` は自由に変更してください。⚠️ 注意点として、チューニングジョブは複数の学習ジョブを実行します。そのため、チューニングジョブが必要とする計算リソースの量に注意してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.tuner import (\n", " CategoricalParameter,\n", " ContinuousParameter,\n", " HyperparameterTuner,\n", " IntegerParameter,\n", ")\n", "\n", "od_algorithm = \"retinanet\" # choose one in (\"faster_rcnn\", \"retinanet\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hparams_range = {\n", " \"lr\": ContinuousParameter(0.0005, 0.1),\n", "}\n", "if od_algorithm == \"faster_rcnn\":\n", " hparams_range.update(\n", " {\n", " \"bbox-rpn-pos-fraction\": ContinuousParameter(0.1, 0.5),\n", " \"bbox-head-pos-fraction\": ContinuousParameter(0.1, 0.5),\n", " }\n", " )\n", "elif od_algorithm == \"retinanet\":\n", " hparams_range.update(\n", " {\n", " \"focal-loss-gamma\": ContinuousParameter(2.5, 5.0),\n", " \"focal-loss-alpha\": ContinuousParameter(0.3, 1.0),\n", " }\n", " )\n", "else:\n", " assert False, f\"{od_algorithm} not supported\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"obj_metric_name = \"validation:loss\"\n", "obj_type = \"Minimize\"\n", "metric_definitions = [\n", " {\"Name\": \"training:loss\", \"Regex\": \"total_loss: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_cls\", \"Regex\": \"loss_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"training:loss_box_reg\", \"Regex\": \"loss_box_reg: ([0-9\\\\.]+)\",},\n", " {\"Name\": obj_metric_name, \"Regex\": \"total_val_loss: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_cls\", \"Regex\": \"val_loss_cls: ([0-9\\\\.]+)\",},\n", " {\"Name\": \"validation:loss_box_reg\", \"Regex\": \"val_loss_box_reg: ([0-9\\\\.]+)\",},\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ハイパーパラメタチューニングのための `estimator` を作成します。`use_spot_instances`、`max_run`、`max_wait` のコメントアウトを外すとスポットインスタンスを使ったハイパーパラメタチューニングが可能です。ただし、利用可能なスポットインスタンスがない場合、指定された時間待機したのちジョブが終了しますのでご注意ください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fixed_hparams = {\n", " # Dataset\n", " \"classes\": json.dumps(classes),\n", " \"dataset-name\": json.dumps(\"sku110k\"),\n", " \"label-name\": json.dumps(\"sku\"),\n", " # Algo specs\n", " \"model-type\": json.dumps(od_algorithm),\n", " \"backbone\": json.dumps(\"R_101_FPN\"),\n", " # Data loader\n", " \"num-iter\": 9000,\n", " \"log-period\": 500,\n", " \"batch-size\": 16,\n", " \"num-workers\": 8,\n", " # Optimization\n", " \"lr-schedule\": 3,\n", " # Prediction specific\n", " \"nms-thr\": 0.2,\n", " \"pred-thr\": 0.1,\n", " \"det-per-img\": 300,\n", " # Evaluation\n", " \"evaluation-type\": \"fast\",\n", "}\n", "\n", "hpo_estimator = Estimator(\n", " image_uri=training_image_uri,\n", " role=role,\n", " sagemaker_session=sm_session,\n", " instance_count=1,\n", " instance_type=\"ml.p3.8xlarge\",\n", " hyperparameters=fixed_hparams,\n", " output_path=f\"s3://{bucket}/{prefix_model}\",\n", "# use_spot_instances=True, # Use spot instances to spare a\n", "# max_run=2 * 60 * 60,\n", "# max_wait=3 * 60 * 60,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuner = HyperparameterTuner(\n", " hpo_estimator,\n", " obj_metric_name,\n", " hparams_range,\n", " metric_definitions,\n", " objective_type=obj_type,\n", " max_jobs=2,\n", " max_parallel_jobs=2,\n", " base_tuning_job_name=f\"hpo-d2-{od_algorithm.replace('_', '-')}\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`fit` を実行することで、ハイパーパラメタチューニングが開始します。このノートブックの設定の場合、チューニングジョブが完了するまでに 2時間程度かかります。2つ上のセルの `fixed_hparams` の中の `num-iter` の値を変えることでチューニングジョブにかかる時間を調整できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuner.fit(\n", " inputs={\n", " \"training\": training_channel,\n", " \"validation\": validation_channel,\n", " \"test\": test_channel,\n", " \"annotation\": annotation_channel,\n", " },\n", " wait=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "チューニングジョブが開始したら、以下のセルを実行して状況を確認します。チューニングジョブの開始に少し時間がかかるので、上記セルを実行してから 1分ほど経ってから以下のセルを実行してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's review outcomes of HyperParameter search\n", "\n", "hpo_tuning_job_name = tuner.latest_tuning_job.name\n", "bayes_metrics = sagemaker.HyperparameterTuningJobAnalytics(\n", " hpo_tuning_job_name\n", ").dataframe()\n", "bayes_metrics.sort_values([\"FinalObjectiveValue\"], ascending=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Amazon SageMaker でのモデルデプロイ\n", "\n", 
"チューニングジョブの完了を待つ間に、推論用コンテナイメージを作成します。モデルの学習と同様に、SageMaker は推論を実行するためにコンテナを使用します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "# execute this cell to review Docker container\n", "pygmentize -l docker Dockerfile.sku110kserving" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以下のセルを実行して、イメージ `Dockerfile.sku110kserving` で定義された Dockerコンテナを構築し、ECR にプッシュします。作成される Docker イメージのサイズは約5GBです。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "./build_and_push.sh sagemaker-d2-serve latest Dockerfile.sku110kserving" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "バッチ推論、つまり大規模な画像の塊に対して推論を実行します。これには [SageMaker Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-batch.html) を使用します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**ここからは、HPO (Hyper Parameter Optimizer) ジョブが完了してから実行してください。**チューニングジョブをアタッチし、最適なモデルを取得します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.tuner import HyperparameterTuner\n", "\n", "tuning_job_id = tuner.latest_tuning_job.name\n", "attached_tuner = HyperparameterTuner.attach(tuning_job_id)\n", "\n", "best_estimator = attached_tuner.best_estimator()\n", "\n", "best_estimator.latest_training_job.describe()\n", "training_job_artifact = best_estimator.latest_training_job.describe()[\"ModelArtifacts\"][\"S3ModelArtifacts\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "また、モデルアーティファクトの S3 URI を指定することもできます。このオプションを使いたい場合は、以下のコードのコメントを外してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# training_job_artifact = best_estimator.model_data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define parameters of inference container\n", "\n", "serve_container_name = \"sagemaker-d2-serve\"\n", "serve_container_version = \"latest\"\n", "serve_image_uri = f\"{account}.dkr.ecr.{region}.amazonaws.com/{serve_container_name}:{serve_container_version}\"\n", "\n", "inference_output = f\"s3://{bucket}/{prefix_predictions}/{serve_container_name}/{Path(test_channel).name}_channel/{training_job_artifact.split('/')[-3]}\"\n", "inference_output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compile SageMaker model object and configure Batch Transform job\n", "\n", "model = PyTorchModel(\n", " name=\"d2-sku110k-model\",\n", " model_data=training_job_artifact,\n", " role=role,\n", " sagemaker_session=sm_session,\n", " entry_point=\"predict_sku110k.py\",\n", " source_dir=\"container_serving\",\n", " image_uri=serve_image_uri,\n", " framework_version=\"1.6.0\",\n", " code_location=f\"s3://{bucket}/{prefix_code}\",\n", ")\n", "\n", "transformer = model.transformer(\n", " instance_count=1,\n", " instance_type=\"ml.p3.2xlarge\", # \"ml.p2.xlarge\", #\n", " output_path=inference_output,\n", " max_payload=16,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以下のセルを実行して、バッチ変換ジョブを開始します。バッチ変換ジョブが完了するまでに 20分程度かかります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Start Batch Transform job\n", "\n", "transformer.transform(\n", " data=test_channel,\n", " data_type=\"S3Prefix\",\n", " 
content_type=\"application/x-image\",\n", " wait=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TransformJobName` にバッチ変換ジョブ名を入れて以下のセルを実行すると、ジョブのステータスを知ることができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "sagemaker_client = boto3.client('sagemaker')\n", "response = sagemaker_client.describe_transform_job(\n", " TransformJobName=''\n", ")\n", "response['TransformJobStatus']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 推論結果の可視化\n", "\n", "バッチ推論ジョブが完了したら、推論結果を可視化してみましょう。ここでは、テスト用画像からランダムに1枚選んで表示します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import io\n", "\n", "import matplotlib\n", "import matplotlib.patches as patches\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def key_from_uri(s3_uri: str) -> str:\n", " \"\"\"Get S3 object key from its URI\"\"\"\n", " return \"/\".join(Path(s3_uri).parts[2:])\n", "\n", "\n", "bucket_rsr = boto3.resource(\"s3\").Bucket(bucket)\n", "predict_objs = list(\n", " bucket_rsr.objects.filter(Prefix=key_from_uri(inference_output) + \"/\")\n", ")\n", "img_objs = list(bucket_rsr.objects.filter(Prefix=key_from_uri(test_channel)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "COLORS = [\n", " (0, 200, 0),\n", "]\n", "\n", "\n", "def plot_predictions_on_image(\n", " p_img: np.ndarray, p_preds: Mapping, score_thr: float = 0.5, show=True\n", ") -> plt.Figure:\n", " r\"\"\"Plot bounding boxes predicted by an inference job on the corresponding image\n", "\n", " Parameters\n", " ----------\n", " p_img : np.ndarray\n", " input image used for prediction\n", " p_preds : Mapping\n", " dictionary with bounding boxes, predicted classes and confidence scores\n", " score_thr : float, optional\n", " show bounding boxes whose confidence score is bigger than `score_thr`, by default 0.5\n", " show : bool, optional\n", " show figure if True do not otherwise, by default True\n", "\n", " Returns\n", " -------\n", " plt.Figure\n", " figure handler\n", "\n", " Raises\n", " ------\n", " IOError\n", " If the prediction dictionary `p_preds` does not contain one of the required keys:\n", " `pred_classes`, `pred_boxes` and `scores`\n", " \"\"\"\n", " for required_key in (\"pred_classes\", \"pred_boxes\", \"scores\"):\n", " if required_key not in p_preds:\n", " raise IOError(f\"Missing required key: {required_key}\")\n", "\n", " fig, fig_axis = plt.subplots(1)\n", " fig_axis.imshow(p_img)\n", " for class_id, bbox, score in zip(\n", " p_preds[\"pred_classes\"], p_preds[\"pred_boxes\"], p_preds[\"scores\"]\n", " ):\n", " if score < score_thr:\n", " break # bounding boxes are sorted by confidence score in descending order\n", " rect = patches.Rectangle(\n", " (bbox[0], bbox[1]),\n", " bbox[2] - bbox[0],\n", " bbox[3] - bbox[1],\n", " linewidth=1,\n", " edgecolor=[float(val) / 255 for val in COLORS[class_id]],\n", " facecolor=\"none\",\n", " )\n", " fig_axis.add_patch(rect)\n", " plt.axis(\"off\")\n", " if show:\n", " plt.show()\n", " return fig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以下のセルを実行すると、推論結果の画像が表示されます。ランダムで画像が表示されるので、何回か実行して結果を確認してみてください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "matplotlib.rcParams[\"figure.dpi\"] = 300\n", "\n", "sample_id = 
np.random.randint(0, len(img_objs), 1)[0]\n", "\n", "img_obj = img_objs[sample_id]\n", "pred_obj = predict_objs[sample_id]\n", "\n", "img = np.asarray(Image.open(io.BytesIO(img_obj.get()[\"Body\"].read())))\n", "preds = json.loads(pred_obj.get()[\"Body\"].read().decode(\"utf-8\"))\n", "\n", "sample_fig = plot_predictions_on_image(img, preds, 0.40, True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## リソースの削除\n", "\n", "不要になったら、課金を停止するためにこのノートブックを実行したノートブックインスタンスを削除してください。なお、インスタンスを「停止」しただけでは EBS ボリュームへの課金は継続するので、完全に課金を止めるためにインスタンスを「停止」してから「削除」を実施してください。なお、削除したあとはインスタンスに保存されているファイルなどにアクセスすることはできません。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }