{ "cells": [ { "cell_type": "markdown", "id": "5a70731c", "metadata": {}, "source": [
"# Training and Inference with LightGBM (Classification) Using SageMaker JumpStart\n",
"* With JumpStart, you only need to supply your own data to train a variety of models and run inference with the resulting model\n",
"* This notebook describes how to run training with the LightGBM classification model\n",
"* The data is a dataset published by AWS\n",
"* Training and inference with the SageMaker SDK come first, followed by inference with boto3\n",
"* Open this notebook with the `Data Science 2.0` image and the `Python 3` kernel\n",
"\n",
"## Table of Contents\n",
"* [Preparation](#Preparation)\n",
"    * [Importing modules](#Importing-modules)\n",
"    * [Getting the data](#Getting-the-data)\n",
"* [Training and inference from the CUI (SageMaker SDK) with SageMaker JumpStart](#Training-and-inference-from-the-CUI-(SageMaker-SDK)-with-SageMaker-JumpStart)\n",
"    * [Training](#Training)\n",
"        * [Uploading the data](#Uploading-the-data)\n",
"        * [Retrieving the training parameters](#Retrieving-the-training-parameters)\n",
"        * [Running the training job](#Running-the-training-job)\n",
"    * [Inference](#Inference)\n",
"        * [Retrieving the inference parameters](#Retrieving-the-inference-parameters)\n",
"        * [Creating the inference endpoint](#Creating-the-inference-endpoint)\n",
"* [Inference with boto3](#Inference-with-boto3)\n",
"    * [Setting constants and clients](#Setting-constants-and-clients)\n",
"    * [Packaging the model and inference code into a tar.gz](#Packaging-the-model-and-inference-code-into-a-tar.gz)\n",
"    * [Creating a SageMaker model with boto3](#Creating-a-SageMaker-model-with-boto3)\n",
"    * [Creating the endpoint configuration with boto3](#Creating-the-endpoint-configuration-with-boto3)\n",
"    * [Creating the endpoint with boto3](#Creating-the-endpoint-with-boto3)\n",
"    * [Running inference with boto3](#Running-inference-with-boto3)\n",
"    * [Deleting the endpoint and related resources with boto3](#Deleting-the-endpoint-and-related-resources-with-boto3)\n"
] }, { "cell_type": "markdown", "id": "bdfbce04", "metadata": {}, "source": [
"## Preparation\n",
"### Importing modules"
] }, { "cell_type": "code", "execution_count": null, "id": "ffce17a9", "metadata": { "tags": [] }, "outputs": [], "source": [
"import sagemaker\n",
"from sagemaker import image_uris, model_uris, script_uris\n",
"from sagemaker.estimator import Estimator\n",
"from sagemaker.session import Session\n",
"from sagemaker import hyperparameters\n",
"import json\n",
"import pandas as pd\n",
"from typing import Final\n",
"import numpy as np"
] }, { "cell_type": "markdown", "id": "97b63a14", "metadata": {}, "source": [
"### Getting the data\n",
"Use the publicly available classification dataset.  \n",
"It consists of MNIST images flattened into columns, with the label stored in the first column."
] }, { "cell_type": "code", "execution_count": null, "id": "bc9a07ff", "metadata": { "tags": [] }, "outputs": [], "source": [
"data_dir: Final[str] = 'classification_data'\n",
"!if [ -d ./{data_dir} ]; then rm -rf ./{data_dir}/;fi\n",
"!mkdir ./{data_dir}/\n",
"!aws s3 sync s3://jumpstart-cache-prod-us-east-1/training-datasets/tabular_multiclass/ ./{data_dir}/"
] }, { "cell_type": "markdown", "id": "1c8436ab", "metadata": {}, "source": [
"## Training and inference from the CUI (SageMaker SDK) with SageMaker JumpStart\n",
"### Training"
] }, { "cell_type": "markdown", "id": "96ee7936", "metadata": {}, "source": [
"#### Uploading the data\n",
"\n",
"* About the training data\n",
"    * To train on your own data with JumpStart, place it in S3 beforehand (you specify the S3 URI when running training)\n",
"* About the data layout\n",
"    * The data must be in csv format with the file name data.csv (the training code accepts only data.csv)\n",
"    * The training data `train/data.csv` is required\n",
"    * The validation data `validation/data.csv` is optional\n",
"    * The test data `test/data.csv` is of course not used for training, but it gets uploaded as a side effect because the whole directory is uploaded together\n",
"    * The target variable must be in column 0 (the training code is implemented to treat column 0 as the target variable)\n",
"* About categorical variables (this dataset has none)\n",
"    * Categorical variables can be handled by placing exactly one json file, with any name, in the root of the data directory\n",
"    * Categorical variables must be encoded as non-negative integers (within the Int32 range)\n",
"    * Categorical columns are specified by column index: the json is a dictionary with the key `cat_index_list` whose value is a list of column indices\n",
"    * In this case, column 1 is the categorical variable\n",
"    * For a concrete example, see the [regression model](./lightgbm_regression.ipynb) notebook, which uses one"
] }, { "cell_type": "markdown", "id": "a4c8c32b", "metadata": {}, "source": [
"Inspect the data (not required for running JumpStart)"
] }, { "cell_type": "code", "execution_count": null, "id": "217d125f", "metadata": {}, "outputs": [], "source": [
"# pd.read_csv(f'{data_dir}/train/data.csv',header=None).head()"
] }, { "cell_type": "markdown", "id": "aa865131", "metadata": {}, "source": [
"* Use the [upload_data](https://sagemaker.readthedocs.io/en/stable/api/utility/session.html#sagemaker.session.Session.upload_data) method to upload the entire directory to S3\n",
"* Here the data is uploaded to the SageMaker default bucket (`sagemaker-{region}-{account}`); to choose a different bucket, use the `bucket` argument\n",
"* The URI printed here is also the value you enter in the GUI (in the GUI, training starts when you enter the S3 URI and click `Train`)"
] }, { "cell_type": "code", "execution_count": null, "id": "f18031ca", "metadata": { "tags": [] }, "outputs": [], "source": [
"# Upload the data to S3\n",
"input_s3_uri: Final[str] = sagemaker.session.Session().upload_data(\n",
" f'./{data_dir}/',\n",
" key_prefix = 'sagemaker-jumpstart/lightgbm_classification/data'\n",
")\n",
"print(f'Upload destination:\\n{input_s3_uri}')"
] }, { "cell_type": "markdown", "id": "e7a4aac4", "metadata": {}, "source": [
"#### Retrieving the training parameters\n",
"* JumpStart provides the container and training code in advance, so retrieve the parameters that point to them\n",
"\n",
"##### Setting constants"
] }, { "cell_type": "code", "execution_count": null, "id": "eb5f8ddd", "metadata": { "tags": [] }, "outputs": 
[], "source": [
"# Set the model and version to run with JumpStart, plus the instance type and count\n",
"model_id: Final[str] = 'lightgbm-classification-model'\n",
"model_version: Final[str] = '*'\n",
"training_instance_type: Final[str] = 'ml.m5.xlarge'\n",
"instance_count: Final[int] = 1"
] }, { "cell_type": "markdown", "id": "4b619554", "metadata": {}, "source": [
"##### Getting the role name\n",
"Get the role to assign to the training instance when the training job runs"
] }, { "cell_type": "code", "execution_count": null, "id": "f8a61840", "metadata": { "tags": [] }, "outputs": [], "source": [
"# Get the role to attach to the JumpStart training job (the same role as this notebook)\n",
"role: Final[str] = sagemaker.get_execution_role()\n",
"print(role)"
] }, { "cell_type": "markdown", "id": "76320358", "metadata": {}, "source": [
"##### Getting the URI of the base model for fine-tuning\n",
"* JumpStart is built around fine-tuning, so retrieve the URI of the base model\n",
"* LightGBM is not something you fine-tune, however, so the artifact holds only a setting that says to perform classification\n",
"* It can be retrieved with the [sagemaker.model_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/model_uris.html#sagemaker.model_uris.retrieve) method"
] }, { "cell_type": "code", "execution_count": null, "id": "5f3a59fb", "metadata": { "tags": [] }, "outputs": [], "source": [
"base_model_uri: Final[str] = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope=\"training\")\n",
"print(f'Model URI:\\n{base_model_uri}')"
] }, { "cell_type": "markdown", "id": "615fa805", "metadata": {}, "source": [
"To inspect the settings, run the cell below (not required for running JumpStart)"
] }, { "cell_type": "code", "execution_count": null, "id": "9be7ac11", "metadata": {}, "outputs": [], "source": [
"# model_dir = 'train-lightgbm-classification-model'\n",
"# !aws s3 cp {base_model_uri} ./\n",
"# !if [ -d ./{model_dir} ]; then rm -rf 
{model_dir}/;fi\n",
"# !mkdir {model_dir}/\n",
"# !tar zxvf train-lightgbm-classification-model.tar.gz -C ./{model_dir}/\n",
"# !cat {model_dir}/train-pytorch-lightgbm-lightgbmmulticlass.json"
] }, { "cell_type": "markdown", "id": "acc3699f", "metadata": {}, "source": [
"##### Getting the S3 URI of the training code\n",
"* The training code is stored in an AWS-managed S3 bucket; retrieve its URI for use when defining the training job\n",
"* It can be retrieved with the [sagemaker.script_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/script_uris.html#sagemaker.script_uris.retrieve) method"
] }, { "cell_type": "code", "execution_count": null, "id": "ca62d152", "metadata": { "tags": [] }, "outputs": [], "source": [
"training_script_uri: Final[str] = script_uris.retrieve(\n",
" model_id=model_id, model_version=model_version, script_scope=\"training\"\n",
")\n",
"print(f'Training code URI:\\n{training_script_uri}')"
] }, { "cell_type": "markdown", "id": "4f180ecb", "metadata": {}, "source": [
"* To inspect the training code, run the cell below (not required for running JumpStart)\n",
"* To customize the training code, download and edit it"
] }, { "cell_type": "code", "execution_count": null, "id": "f0431fce", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [
"training_script_dir: Final[str] = 'lightgbm_classification_training_script_dir'\n",
"!aws s3 cp {training_script_uri} ./\n",
"!if [ -d ./{training_script_dir} ]; then rm -rf ./{training_script_dir}/;fi\n",
"!mkdir ./{training_script_dir}/\n",
"!tar zxvf sourcedir.tar.gz -C ./{training_script_dir}/\n",
"!pygmentize ./{training_script_dir}/transfer_learning.py"
] }, { "cell_type": "markdown", "id": "87e3d3c4", "metadata": {}, "source": [
"##### Getting the URI of the training container image\n",
"* The image is stored in an AWS-managed ECR repository; retrieve its URI\n",
"* It can be retrieved with the [sagemaker.image_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/image_uris.html#sagemaker.image_uris.retrieve) method"
] }, { "cell_type": "code", "execution_count": null, "id": "1b693bb4", "metadata": { "tags": [] }, "outputs": [], "source": [
"training_image_uri: Final[str] = image_uris.retrieve(\n",
" region=None,\n",
" framework=None,\n",
" image_scope=\"training\",\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
" instance_type=training_instance_type,\n",
")\n",
"print(f'Container image URI:\\n{training_image_uri}')"
] }, { "cell_type": "markdown", "id": "e6cf5e02", "metadata": {}, "source": [
"##### Getting the default hyperparameters\n",
"* They can be retrieved with the [sagemaker.hyperparameters.retrieve_default](https://sagemaker.readthedocs.io/en/stable/api/utility/hyperparameters.html#sagemaker.hyperparameters.retrieve_default) method\n",
"* To change a hyperparameter, overwrite the corresponding entry in the returned dictionary"
] }, { "cell_type": "code", "execution_count": null, "id": "8d133e81", "metadata": { "tags": [] }, "outputs": [], "source": [
"hps = hyperparameters.retrieve_default(\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
")\n",
"print(f'Hyperparameters:\\n{json.dumps(hps,indent=4)}')"
] }, { "cell_type": "markdown", "id": "97662751", "metadata": {}, "source": [
"#### Running the training job\n",
"* As with regular SageMaker Training, create an `estimator` instance from the [Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator) class and run it with the [fit](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator.fit) method\n",
"* Pass the values retrieved so far as arguments when creating the `estimator` instance\n",
"* If you modified the training code locally, specify the local directory instead of `training_script_uri`\n",
"* Pass the S3 URI of the training data as the argument to fit"
] }, { "cell_type": "code", "execution_count": null, "id": "d009d3da", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [
"estimator = Estimator(\n",
" image_uri=training_image_uri,\n",
" source_dir=training_script_uri,\n", " 
model_uri=base_model_uri,\n",
" entry_point=\"transfer_learning.py\",\n",
" role=role,\n",
" hyperparameters=hps,\n",
" instance_count=instance_count,\n",
" instance_type=training_instance_type,\n",
")\n",
"estimator.fit({\"training\": input_s3_uri})\n"
] }, { "cell_type": "markdown", "id": "bf22cb5d", "metadata": {}, "source": [
"### Inference"
] }, { "cell_type": "markdown", "id": "c79566a5", "metadata": {}, "source": [
"#### Retrieving the inference parameters\n",
"* JumpStart provides the container and inference code in advance, so retrieve the parameters that point to them\n",
"\n",
"##### Getting the S3 URI of the inference code\n",
"* The inference code is stored in an AWS-managed S3 bucket; retrieve its URI for use when deploying the model\n",
"* It can be retrieved with the [sagemaker.script_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/script_uris.html#sagemaker.script_uris.retrieve) method"
] }, { "cell_type": "code", "execution_count": null, "id": "c3e671ee", "metadata": { "tags": [] }, "outputs": [], "source": [
"inference_script_uri: Final[str] = script_uris.retrieve(\n",
" model_id=model_id, model_version=model_version, script_scope=\"inference\"\n",
")\n",
"print(f'Inference code URI:\\n{inference_script_uri}')"
] }, { "cell_type": "markdown", "id": "3ba0b82a", "metadata": {}, "source": [
"* To inspect the inference code, run the cell below (not required for running JumpStart)\n",
"* To customize the inference code, download and edit it"
] }, { "cell_type": "code", "execution_count": null, "id": "2e832091", "metadata": { "scrolled": true }, "outputs": [], "source": [
"# inference_script_dir: Final[str] = 'lightgbm_classification_inference_script_dir'\n",
"# !aws s3 cp {inference_script_uri} ./\n",
"# !if [ -d ./{inference_script_dir} ]; then rm -rf ./{inference_script_dir}/;fi\n",
"# !mkdir ./{inference_script_dir}/\n",
"# !tar zxvf sourcedir.tar.gz -C ./{inference_script_dir}/\n",
"# !pygmentize ./{inference_script_dir}/inference.py"
] }, { "cell_type": "markdown", "id": "c53b0b4b", "metadata": {}, "source": [
"##### Getting the URI of the inference container image\n",
"* AWS hosts the image in an ECR repository that it manages; retrieve its URI\n",
"* It can be retrieved with the [sagemaker.image_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/image_uris.html#sagemaker.image_uris.retrieve) method"
] }, { "cell_type": "code", "execution_count": null, "id": "d8a80d62", "metadata": { "tags": [] }, "outputs": [], "source": [
"inference_image_uri: Final[str] = image_uris.retrieve(\n",
" region=None,\n",
" framework=None,\n",
" image_scope=\"inference\",\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
" instance_type=training_instance_type,\n",
")\n",
"print(f'Container image URI:\\n{inference_image_uri}')"
] }, { "cell_type": "markdown", "id": "76cbe69f", "metadata": {}, "source": [
"#### Creating the inference endpoint\n",
"Create the endpoint with the [deploy](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase.deploy) method of [Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator)"
] }, { "cell_type": "code", "execution_count": null, "id": "3f099e14", "metadata": { "tags": [] }, "outputs": [], "source": [
"predictor = estimator.deploy(\n",
" instance_type = 'ml.m5.large',\n",
" initial_instance_count = 1,\n",
" entry_point='inference.py',\n",
" source_dir=inference_script_uri,\n",
" image_uri = inference_image_uri\n",
" \n",
")"
] }, { "cell_type": "markdown", "id": "c8746166", "metadata": {}, "source": [
"#### Running inference\n",
"* By default the endpoint accepts only `text/csv` (see `inference.py` and `constants.py` in the inference code), so set a [CSVSerializer](https://sagemaker.readthedocs.io/en/stable/api/inference/serializers.html#sagemaker.serializers.CSVSerializer) on the caller (predictor) side\n",
"* With `CSVSerializer` set, requests to the API ([predict](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict)) are sent with `content_type='text/csv'`, and an ndarray passed in is serialized to csv behind the scenes"
] }, { "cell_type": "code", "execution_count": null, "id": "ccfaa006", "metadata": { "tags": [] }, "outputs": [], "source": [
"# Have the predictor convert inputs to csv and send requests in csv format\n",
"predictor.serializer = sagemaker.serializers.CSVSerializer()"
] }, { "cell_type": "code", "execution_count": null, "id": "f55cd49d", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [
"# Pattern 1: send the request as a csv string\n",
"np.argmax(json.loads(predictor.predict(pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].to_csv(header=False,index=False)).decode('utf-8'))['probabilities'])\n",
"# # Pattern 2: send the request as an ndarray\n",
"# np.argmax(json.loads(predictor.predict(pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].values).decode('utf-8'))['probabilities'])"
] }, { "cell_type": "markdown", "id": "7ce3eb42", "metadata": {}, "source": [
"#### Deleting the endpoint\n",
"* Deleting the endpoint stops its instance\n",
"* It can be deleted with [delete_endpoint](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.delete_endpoint)"
] }, { "cell_type": "code", "execution_count": null, "id": "a963e645", "metadata": { "tags": [] }, "outputs": [], "source": [
"predictor.delete_endpoint()"
] }, { "cell_type": "markdown", "id": "05d00843", "metadata": {}, "source": [
"## Inference with boto3\n",
"Endpoint creation and inference are often done from boto3 rather than the SageMaker SDK, so this section shows how"
] }, { "cell_type": "markdown", "id": "6312cde8", "metadata": {}, "source": [
"### Setting constants and clients"
] }, { "cell_type": "code", "execution_count": null, "id": "653cca9a", "metadata": { "tags": [] }, "outputs": [], "source": [
"import boto3\n",
"sm_client = boto3.client('sagemaker')\n",
"smr_client = boto3.client('sagemaker-runtime')\n",
"endpoint_inservice_waiter = sm_client.get_waiter('endpoint_in_service')"
] }, { "cell_type": "code", "execution_count": null, "id": "68dd961d", "metadata": { "tags": [] }, "outputs": [], "source": [
"model_name: Final[str] = 'LightgbmClassification'\n",
"endpoint_config_name: Final[str] = model_name + 'EndpointConfig'\n",
"endpoint_name: Final[str] = model_name + 'Endpoint'"
] }, { "cell_type": "markdown", "id": "f6a86d9c", "metadata": {}, "source": [
"### Packaging the model and inference code into a tar.gz\n",
"To launch an inference endpoint, a model must first be registered in SageMaker.  \n",
"A `model` here means the S3 URI of a tar.gz containing the machine learning model plus the inference code, together with the container that runs it and related settings.  \n",
"Right after training, the artifact holds only the trained lightgbm model (pkl) and no inference code,  \n",
"so bundle the inference code with it and upload the result to S3 again (the SageMaker SDK did this automatically behind the scenes).  \n",
"\n",
"If the inference code is placed as `inference.py` directly under a `code` directory at the root of the `tar.gz`, it is picked up automatically. (The name can be changed, but that requires adjusting environment variables, so it is not recommended.)"
] }, { "cell_type": "code", "execution_count": null, "id": "af55ad29", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [
"# Get the model URI from the training job record and download it locally\n",
"!aws s3 cp {estimator.latest_training_job.describe()['ModelArtifacts']['S3ModelArtifacts']} ./\n",
"# Download the inference code used earlier\n",
"!aws s3 cp {inference_script_uri} ./\n",
"\n",
"# Extract the model\n",
"inference_model_dir: Final[str] = 'model'\n",
"!if [ -d ./{inference_model_dir} ]; then rm -rf ./{inference_model_dir}/;fi\n",
"!mkdir ./{inference_model_dir}/\n",
"!tar zxvf ./model.tar.gz -C ./{inference_model_dir}/\n",
"\n",
"# Add the inference code under code/\n",
"inference_code_dir: Final[str] = 'code'\n",
"!if [ -d ./{inference_code_dir} ]; then rm -rf ./{inference_code_dir}/;fi\n",
"!mkdir ./{inference_code_dir}/\n",
"!tar zxvf ./sourcedir.tar.gz -C ./{inference_code_dir}/\n",
"!mv ./code/ model/\n",
"\n",
"# Re-compress\n",
"!rm ./{inference_model_dir}.tar.gz\n",
"%cd {inference_model_dir}/\n",
"!tar zcvf model.tar.gz .\n",
"%cd ..\n",
"\n",
"# Upload the tar.gz containing the model and inference code to S3\n",
"inference_model_uri: Final[str] = sagemaker.session.Session().upload_data(\n",
" f'./{inference_model_dir}/{inference_model_dir}.tar.gz',\n",
" key_prefix = 'sagemaker-jumpstart/lightgbm/model'\n",
")\n",
"print(f'Upload destination:\\n{inference_model_uri}')"
] }, { "cell_type": "markdown", "id": "8d1bb0af", "metadata": {}, "source": [
"### Creating a SageMaker model with boto3\n",
"Specify the uploaded model `model.tar.gz` and the container image.  \n",
"Use the [create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) method."
] }, { "cell_type": "code", "execution_count": null, "id": "ffbba6b8", "metadata": { "tags": [] }, "outputs": [], "source": [
"response = sm_client.create_model(\n",
" ModelName=model_name,\n",
" PrimaryContainer={\n",
" # Same container image URI as with the SageMaker SDK\n",
" 'Image': inference_image_uri,\n",
" # S3 URI of the repackaged model.tar.gz uploaded above\n",
" 'ModelDataUrl': inference_model_uri,\n",
" },\n",
" # Same role as with the SageMaker SDK\n",
" ExecutionRoleArn=role,\n",
")\n",
"print(response)"
] }, { "cell_type": "markdown", "id": "7dba8ecb", "metadata": {}, "source": [
"### Creating the endpoint configuration with boto3\n",
"Specify the model to use, the instance type and count, and so on.  \n",
"Use the [create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) method."
] }, { "cell_type": "code", "execution_count": null, "id": "f08b19c5", "metadata": { "tags": [] }, "outputs": [], "source": [
"response = sm_client.create_endpoint_config(\n",
" EndpointConfigName=endpoint_config_name,\n",
" ProductionVariants=[\n",
" {\n",
" 'VariantName': 'AllTraffic',\n", " 
'ModelName': model_name,\n",
" 'InitialInstanceCount': 1,\n",
" 'InstanceType': 'ml.m5.xlarge',\n",
" },\n",
" ],\n",
")\n",
"print(response)"
] }, { "cell_type": "markdown", "id": "fdb2a026", "metadata": {}, "source": [
"### Creating the endpoint with boto3\n",
"Use the [create_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint) method."
] }, { "cell_type": "code", "execution_count": null, "id": "62ff5f13", "metadata": { "tags": [] }, "outputs": [], "source": [
"response = sm_client.create_endpoint(\n",
" EndpointName=endpoint_name,\n",
" EndpointConfigName=endpoint_config_name,\n",
")\n",
"# Wait until the endpoint is in service\n",
"endpoint_inservice_waiter.wait(\n",
" EndpointName=endpoint_name,\n",
" WaiterConfig={'Delay': 5,}\n",
")"
] }, { "cell_type": "markdown", "id": "3c127d49", "metadata": {}, "source": [
"### Running inference with boto3\n",
"Inference can be run with [invoke_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime.html#SageMakerRuntime.Client.invoke_endpoint).  \n",
"Note that the client is `boto3.client('sagemaker-runtime')`, not `boto3.client('sagemaker')`."
] }, { "cell_type": "code", "execution_count": null, "id": "4a68d4a9", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [
"request_args = {\n",
" 'EndpointName': endpoint_name,\n",
" 'ContentType' : 'text/csv',\n",
" 'Body' : pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].to_csv(header=False, index=False)\n",
"}\n",
"response = smr_client.invoke_endpoint(**request_args)\n",
"np.argmax(json.loads(response['Body'].read())['probabilities'])"
] }, { "cell_type": "markdown", "id": "eb141c95", "metadata": {}, "source": [
"### Deleting the endpoint and related resources with boto3"
] }, { "cell_type": "code", "execution_count": null, "id": "5b61a459", "metadata": { "tags": [] }, "outputs": [], "source": [ "r = 
sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "r = sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "r = sm_client.delete_model(ModelName=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "e1737fcc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 2.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }