{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a70731c",
   "metadata": {},
   "source": [
    "# SageMaker JumpStart を用いた LightGBM (分類)のトレーニングと推論\n",
    "* JumpStart では独自のデータを用意するだけで、様々なモデルの学習と出来たモデルの推論ができる\n",
    "* このノートブックでは LightGBM の分類モデルを用いたトレーニングの動かし方を記述する\n",
    "* データについては、AWS が公開しているデータを利用する\n",
    "* SageMaker SDK を使ったトレーニングと推論を記載し、最後に boto3 を使った推論を記載している\n",
    "* このノートブックは `Data Science 2.0` イメージ、`Python 3` カーネルで開いてください\n",
    "\n",
    "## Tabel of Contents\n",
    "* [事前準備](#事前準備)\n",
    "  * [モジュールのインポート](#モジュールのインポート)\n",
    "  * [データ取得](#データ取得)\n",
    "* [SageMaker JumpStart を使って CUI(SageMaker SDK) でトレーニングと推論](#SageMaker-JumpStart-を使って-CUI(SageMaker-SDK)-でトレーニングと推論)\n",
    "  * [トレーニング](#トレーニング)\n",
    "    * [データアップロード](#データアップロード)\n",
    "    * [トレーニングパラメータの取得](#トレーニングパラメータの取得)\n",
    "    * [トレーニングジョブ実行](#トレーニングジョブ実行)\n",
    "  * [推論](#推論)\n",
    "    * [推論パラメータの取得](#トレーニングパラメータの取得)\n",
    "    * [推論エンドポイント作成](#推論エンドポイント作成)\n",
    "* [boto3 で推論](#boto3-で推論)\n",
    "  * [定数やクライアントの設定](#定数やクライアントの設定)\n",
    "  * [モデルと推論コードを tar.gz に固める](#モデルと推論コードを-tar.gz-に固める)\n",
    "  * [boto3 でSageMaker でモデルの作成](#boto3-でSageMaker-でモデルの作成)\n",
    "  * [boto3 でエンドポイントの設定を作成](#boto3-でエンドポイントの設定を作成)\n",
    "  * [boto3 でエンドポイントを作成する](#boto3-でエンドポイントを作成する)\n",
    "  * [boto3 で推論する](#boto3-で推論する)\n",
    "  * [boto3 でエンドポイント他を削除](#boto3-でエンドポイント他を削除)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdfbce04",
   "metadata": {},
   "source": [
    "## 事前準備\n",
    "### モジュールのインポート"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffce17a9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import image_uris, model_uris, script_uris\n",
    "from sagemaker.estimator import Estimator\n",
    "from sagemaker.session import Session\n",
    "from sagemaker import hyperparameters\n",
    "import json\n",
    "import pandas as pd\n",
    "from typing import Final\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97b63a14",
   "metadata": {},
   "source": [
    "### データ取得\n",
    "公開されている分類用データを使う。  \n",
    "mnist の画像をカラム展開されたものであり、最初の列に教師ラベルが格納されている"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9a07ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_dir: Final[str] = 'classification_data'\n",
    "!if [ -d ./{data_dir} ]; then rm -rf ./{data_dir}/;fi\n",
    "!mkdir ./{data_dir}/\n",
    "!aws s3 sync  s3://jumpstart-cache-prod-us-east-1/training-datasets/tabular_multiclass/ ./{data_dir}/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c8436ab",
   "metadata": {},
   "source": [
    "## SageMaker JumpStart を使って CUI(SageMaker SDK) でトレーニングと推論\n",
    "### トレーニング"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96ee7936",
   "metadata": {},
   "source": [
    "#### データアップロード\n",
    "\n",
    "* トレーニングデータについて\n",
    "    * JumpStart で自分のデータでトレーニングするには予め S3 に配置する(トレーニング実行時に S3 の URI を指定する)\n",
    "* データの持ち方について\n",
    "    * csv 形式でファイル名を data.csv にする必要がある(トレーニングコードが data.csv しか受け付けないようになっている)\n",
    "    * 訓練用データの `train/data.csv` は必ず用意する\n",
    "    * 評価用データの`validation/data.csv` はオプション\n",
    "    * テスト用データの `test/data.csv` はトレーニング時にもちろん使わないがまとめてアップロードしているので副次的にアップロードされる\n",
    "    * ターゲット変数は必ず 0 列目に入れること(トレーニングコードが 0 列目をターゲット変数として認識するように実装されている)\n",
    "* カテゴリー変数について(このデータにカテゴリー変数はない)\n",
    "    * データディレクトリのルートに任意の json ファイルを1つだけ含むことでカテゴリカル変数を扱うことができる\n",
    "    * カテゴリー変数は、0 以上の整数(Int32の範囲内)でエンコードされている必要がある\n",
    "    * カテゴリー変数は列のインデックスで辞書形式でキーに `cat_index_list` で、値に列のインデックスをリスト形式で渡す\n",
    "    * 今回は 1 列目がカテゴリー変数\n",
    "    * 実際の例は[回帰モデル](./lightgbm_regression.ipynb)で利用しているので参照のこと"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4c8c32b",
   "metadata": {},
   "source": [
    "データの確認(JumpStart を動かすのには不要)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217d125f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.read_csv(f'{data_dir}/train/data.csv',header=None).head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa865131",
   "metadata": {},
   "source": [
    "* データアップロードは [upload_data](https://sagemaker.readthedocs.io/en/stable/api/utility/session.html#sagemaker.session.Session.upload_data) メソッドを利用して、ディレクトリまるごと S3 にアップロードする\n",
    "* ここでは SageMaker のデフォルトバケット(`sagemaker-{region}-{account}`にアップロードしているが、任意のバケットを選択するときは `bucket` 引数を使用する\n",
    "* ここで出力される URI は、GUI で入力する値でもある(GUI の場合は、S3 の URI を入力したあと `Train` をクリックすれば学習が開始される  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18031ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 使うデータを S3 にアップロード\n",
    "input_s3_uri: Final[str] = sagemaker.session.Session().upload_data(\n",
    "    f'./{data_dir}/',\n",
    "    key_prefix = 'sagemaker-jumpstart/lightgbm_classification/data'\n",
    ")\n",
    "print(f'アップロード先 : \\n{input_s3_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a4aac4",
   "metadata": {},
   "source": [
    "#### トレーニングパラメータの取得\n",
    "* JumpStart は予めコンテナやトレーニングコードを用意しているので、そのパラメータを取得する\n",
    "\n",
    "##### 定数の設定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb5f8ddd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# JumpStart で動かすモデルとバージョン、インスタンスタイプと台数を設定\n",
    "model_id: Final[str] = 'lightgbm-classification-model'\n",
    "model_version: Final[str] = '*'\n",
    "training_instance_type: Final[str] = 'ml.m5.xlarge'\n",
    "instance_count: Final[int] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b619554",
   "metadata": {},
   "source": [
    "##### ロール名を取得\n",
    "トレーニングジョブを動かす際に、トレーニングインスタンスに割り当てるロールを取得"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a61840",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# JumpStart で動かすトレーニングジョブにアタッチするロールを取得(Notebook と同一)\n",
    "role: Final[str] = sagemaker.get_execution_role()\n",
    "print(role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76320358",
   "metadata": {},
   "source": [
    "##### Fine-Tune の元となるモデルの URI を取得\n",
    "* JumpStart は Fine-Tune が基本なので、Fine-Tune の元となるモデルの URI を取得\n",
    "* ただし、LightGBM は Fine-Tune するものではないので classification するという設定値だけが格納されている\n",
    "* [sagemaker.model_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/model_uris.html#sagemaker.model_uris.retrieve) メソッドで取得できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3a59fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "base_model_uri: Final[str] = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope=\"training\")\n",
    "print(f'モデルの URI:\\n{base_model_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "615fa805",
   "metadata": {},
   "source": [
    "設定を確認したい場合は下記を実行( JumpStart を動かすのには不要な作業)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9be7ac11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_dir = 'train-lightgbm-classification-model'\n",
    "# !aws s3 cp {base_model_uri} ./\n",
    "# !if [ -d ./{model_dir} ]; then rm -rf {model_dir}/;fi\n",
    "# !mkdir {model_dir}/\n",
    "# !tar zxvf train-lightgbm-classification-model.tar.gz -C ./{model_dir}/\n",
    "# !cat {model_dir}/train-pytorch-lightgbm-lightgbmmulticlass.json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acc3699f",
   "metadata": {},
   "source": [
    "##### トレーニングコードの S3 URI を取得\n",
    "* トレーニングコードは AWS が管理する S3 に格納されており、トレーニングジョブを定義する時に使うため取得する  \n",
    "* [sagemaker.script_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/script_uris.html#sagemaker.script_uris.retrieve) メソッドで取得できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca62d152",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_script_uri: Final[str] = script_uris.retrieve(\n",
    "    model_id=model_id, model_version=model_version, script_scope=\"training\"\n",
    ")\n",
    "print(f'コードの URI:\\n{training_script_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f180ecb",
   "metadata": {},
   "source": [
    "* トレーニングコードを確認したい場合は下記を実行( JumpStart を動かすのには不要な作業)\n",
    "* トレーニングコードをカスタマイズしたい場合はダウンロードして編集する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0431fce",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_script_dir: Final[str] = 'lightgbm_classification_training_script_dir'\n",
    "!aws s3 cp {training_script_uri} ./\n",
    "!if [ -d ./{training_script_dir} ]; then rm -rf ./{training_script_dir}/;fi\n",
    "!mkdir ./{training_script_dir}/\n",
    "!tar zxvf sourcedir.tar.gz -C ./{training_script_dir}/\n",
    "!pygmentize ./{training_script_dir}/transfer_learning.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87e3d3c4",
   "metadata": {},
   "source": [
    "##### トレーニングコンテナイメージの URI を取得\n",
    "* AWS が管理する ECR に格納されており、その URI を取得する\n",
    "* [sagemaker.image_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/image_uris.html#sagemaker.image_uris.retrieve) メソッドで取得できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b693bb4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_image_uri: Final[str] = image_uris.retrieve(\n",
    "    region=None,\n",
    "    framework=None,\n",
    "    image_scope=\"training\",\n",
    "    model_id=model_id,\n",
    "    model_version=model_version,\n",
    "    instance_type=training_instance_type,\n",
    ")\n",
    "print(f'コンテナの URI:\\n{training_image_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6cf5e02",
   "metadata": {},
   "source": [
    "##### デフォルトのハイパーパラメータを取得\n",
    "* [sagemaker.hyperparameters.retrieve_default](https://sagemaker.readthedocs.io/en/stable/api/utility/hyperparameters.html#sagemaker.hyperparameters.retrieve_default) メソッドで取得できる\n",
    "* ハイパーパラメータを変える場合は取得結果の辞書を上書きする"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d133e81",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hps = hyperparameters.retrieve_default(\n",
    "    model_id=model_id,\n",
    "    model_version=model_version,\n",
    ")\n",
    "print(f'ハイパーパラメータ\\n{json.dumps(hps,indent=4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97662751",
   "metadata": {},
   "source": [
    "#### トレーニングジョブ実行\n",
    "* 通常の SageMaker Training と同じ様に [Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator) クラスから `estimator` インスタンスを生成し、 [fit](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator.fit) メソッドで実行する\n",
    "* 今まで取得してきた設定値を引数に入れて `estimator` インスタンスを生成する\n",
    "* `training_script_uri` について、ローカルで書き換えた場合はローカルのディレクトリを指定する\n",
    "* fit の引数にトレーニングデータの S3 URI を指定する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d009d3da",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "estimator = Estimator(\n",
    "    image_uri=training_image_uri,\n",
    "    source_dir=training_script_uri,\n",
    "    model_uri=base_model_uri,\n",
    "    entry_point=\"transfer_learning.py\",\n",
    "    role=role,\n",
    "    hyperparameters=hps,\n",
    "    instance_count=instance_count,\n",
    "    instance_type=training_instance_type,\n",
    ")\n",
    "estimator.fit({\"training\": input_s3_uri})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf22cb5d",
   "metadata": {},
   "source": [
    "### 推論"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c79566a5",
   "metadata": {},
   "source": [
    "#### 推論パラメータの取得\n",
    "* JumpStart は予めコンテナや推論コードを用意しているので、そのパラメータを取得する\n",
    "\n",
    "##### トレーニングコードの S3 URI を取得\n",
    "* 推論コードは AWS が管理する S3 に格納されており、モデルデプロイに使うため取得する  \n",
    "* [sagemaker.script_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/script_uris.html#sagemaker.script_uris.retrieve) メソッドで取得できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e671ee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_script_uri: Final[str] = script_uris.retrieve(\n",
    "    model_id=model_id, model_version=model_version, script_scope=\"inference\"\n",
    ")\n",
    "print(f'推論コードのURL:\\n{inference_script_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ba0b82a",
   "metadata": {},
   "source": [
    "* 推論コードを確認したい場合は下記を実行( JumpStart を動かすのには不要な作業)\n",
    "* 推論コードをカスタマイズしたい場合はダウンロードして編集する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e832091",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# inference_script_dir: Final[str] = 'lightgbm_classification_inference_script_dir'\n",
    "# !aws s3 cp {inference_script_uri} ./\n",
    "# !if [ -d ./{inference_script_dir} ]; then rm -rf ./{inference_script_dir}/;fi\n",
    "# !mkdir ./{inference_script_dir}/\n",
    "# !tar zxvf sourcedir.tar.gz -C ./{inference_script_dir}/\n",
    "# !pygmentize ./{inference_script_dir}/inference.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c53b0b4b",
   "metadata": {},
   "source": [
    "##### 推論コンテナイメージの URI を取得\n",
    "* AWS が管理する ECR に格納されており、その URI を取得する\n",
    "* [sagemaker.image_uris.retrieve](https://sagemaker.readthedocs.io/en/stable/api/utility/image_uris.html#sagemaker.image_uris.retrieve) メソッドで取得できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8a80d62",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_image_uri: Final[str] = image_uris.retrieve(\n",
    "    region=None,\n",
    "    framework=None,\n",
    "    image_scope=\"inference\",\n",
    "    model_id=model_id,\n",
    "    model_version=model_version,\n",
    "    instance_type=training_instance_type,\n",
    ")\n",
    "print(f'コンテナの URI:\\n{inference_image_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76cbe69f",
   "metadata": {},
   "source": [
    "#### 推論エンドポイント作成\n",
    "[Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator) の [deploy](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase.deploy) メソッドでエンドポイント作成を行う"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f099e14",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictor = estimator.deploy(\n",
    "    instance_type = 'ml.m5.large',\n",
    "    initial_instance_count  = 1,\n",
    "    entry_point='inference.py',\n",
    "    source_dir=inference_script_uri,\n",
    "    image_uri = inference_image_uri\n",
    "    \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8746166",
   "metadata": {},
   "source": [
    "#### 推論実行\n",
    "* エンドポイントはデフォルトだと `text/csv` しか受け付けないので(推論コードの `inference.py` と `constants.py` を参照)、呼び出しもと(predictor)側に [CSVSerializer](https://sagemaker.readthedocs.io/en/stable/api/inference/serializers.html#sagemaker.serializers.CSVSerializer) を設定する\n",
    "* `CSVSerializer` を設定すると、API へのリクエスト([predict](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict))時に `content_type='text/csv'` が設定され、また ndarray を渡しても裏側で csv にシリアライズしてくれる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccfaa006",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# csvに変換して、csv 形式でリクエストを投げてくれるようになる\n",
    "predictor.serializer = sagemaker.serializers.CSVSerializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f55cd49d",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# csv でリクエストするパターン\n",
    "np.argmax(json.loads(predictor.predict(pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].to_csv(header=False,index=False)).decode('utf-8'))['probabilities'])\n",
    "# # ndarray でリクエストするパターン\n",
    "# np.argmax(json.loads(predictor.predict(pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].values).decode('utf-8'))['probabilities'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ce3eb42",
   "metadata": {},
   "source": [
    "#### エンドポイント削除\n",
    "* エンドポイントを削除することでインスタンスが停止される\n",
    "* [delete_endpoint](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.delete_endpoint) で削除できる"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a963e645",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictor.delete_endpoint()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05d00843",
   "metadata": {},
   "source": [
    "## boto3 で推論\n",
    "エンドポイント作成や推論は SageMaker SDK ではなく、boto3 からやることも多いのでやり方を紹介"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6312cde8",
   "metadata": {},
   "source": [
    "### 定数やクライアントの設定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "653cca9a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "sm_client = boto3.client('sagemaker')\n",
    "smr_client = boto3.client('sagemaker-runtime')\n",
    "endpoint_inservice_waiter = sm_client.get_waiter('endpoint_in_service')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68dd961d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name: Final[str] = 'LightgbmClassification'\n",
    "endpoint_config_name: Final[str] = model_name + 'EndpointConfig'\n",
    "endpoint_name: Final[str] = model_name + 'Endpoint'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a86d9c",
   "metadata": {},
   "source": [
    "### モデルと推論コードを tar.gz に固める\n",
    "推論エンドポイントを立ち上げるためには、SageMaker 上でモデルを登録する必要がある。  \n",
    "ここでいう`モデル`とは、「機械学習モデル+推論コード」を tar.gz の S3 URI と、モデルを動かすコンテナなどを指す。  \n",
    "トレーニングが終わった段階では、lightgbm の学習済モデル(pkl) だけなので、当然推論コードを含まないので、  \n",
    "推論コードを梱包して S3 にアップロードしなおす(SageMaker SDK だと裏側で勝手にやってくれていた)。  \n",
    "  \n",
    "推論コードは、`tar.gz` のルートディレクトリに `code` ディレクトリを配置しその直下に`inference.py`で置くと勝手に読んでくれる。(名前を変えることもできるか環境変数をいじる必要が出てくるのでお勧めしない)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af55ad29",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# トレーニングの記録からモデルの URI を取得して、ローカルにダウンロードする\n",
    "!aws s3 cp {estimator.latest_training_job.describe()['ModelArtifacts']['S3ModelArtifacts']} ./\n",
    "# 先程使った 推論コードをダウンロードする\n",
    "!aws s3 cp {inference_script_uri} ./\n",
    "\n",
    "# モデルを解凍\n",
    "inference_model_dir: Final[str] = 'model'\n",
    "!if [ -d ./{inference_model_dir} ]; then rm -rf ./{inference_model_dir}/;fi\n",
    "!mkdir ./{inference_model_dir}/\n",
    "!tar zxvf ./model.tar.gz -C ./{inference_model_dir}/\n",
    "\n",
    "# コードを追加\n",
    "inference_code_dir: Final[str] = 'code'\n",
    "!if [ -d ./{inference_code_dir} ]; then rm -rf ./{inference_code_dir}/;fi\n",
    "!mkdir ./{inference_code_dir}/\n",
    "!tar zxvf ./sourcedir.tar.gz -C ./{inference_code_dir}/\n",
    "!mv ./code/ model/\n",
    "\n",
    "# 再圧縮\n",
    "!rm ./{inference_model_dir}.tar.gz\n",
    "%cd {inference_model_dir}/\n",
    "!tar zcvf model.tar.gz .\n",
    "%cd ..\n",
    "\n",
    "# モデルとトレーニングコードの tar.gz を S3 にアップロード\n",
    "inference_model_uri: Final[str] = sagemaker.session.Session().upload_data(\n",
    "    f'./{inference_model_dir}/{inference_model_dir}.tar.gz',\n",
    "    key_prefix = 'sagemaker-jumpstart/lightgbm/model'\n",
    ")\n",
    "print(f'アップロード先 : \\n{inference_model_uri}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d1bb0af",
   "metadata": {},
   "source": [
    "### boto3 で SageMaker でモデルの作成\n",
    "アップロードしたモデル `model.tar.gz` と、コンテナイメージを設定する  \n",
    "[create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) メソッドで設定する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffbba6b8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = sm_client.create_model(\n",
    "    ModelName=model_name,\n",
    "    PrimaryContainer={\n",
    "        # SageMaker SDK の時と同じ URI を指定\n",
    "        'Image': inference_image_uri,\n",
    "        # SageMaker SDK の時と同じ URI を指定\n",
    "        'ModelDataUrl': inference_model_uri,\n",
    "    },\n",
    "    # SageMaker SDK の時と同じ role を指定\n",
    "    ExecutionRoleArn=role,\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7dba8ecb",
   "metadata": {},
   "source": [
    "### boto3 でエンドポイントの設定を作成\n",
    "使用するモデル、インスタンスの種類と台数などを設定する。  \n",
    "[create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) メソッドで設定する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08b19c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = sm_client.create_endpoint_config(\n",
    "    EndpointConfigName=endpoint_config_name,\n",
    "    ProductionVariants=[\n",
    "        {\n",
    "            'VariantName': 'AllTrafic',\n",
    "            'ModelName': model_name,\n",
    "            'InitialInstanceCount': 1,\n",
    "            'InstanceType': 'ml.m5.xlarge',\n",
    "        },\n",
    "    ],\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdb2a026",
   "metadata": {},
   "source": [
    "### boto3 でエンドポイントを作成する\n",
    "[create_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint) メソッドで作成する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62ff5f13",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = sm_client.create_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    EndpointConfigName=endpoint_config_name,\n",
    ")\n",
    "# エンドポイントが立ち上がるまで待つ\n",
    "endpoint_inservice_waiter.wait(\n",
    "    EndpointName=endpoint_name,\n",
    "    WaiterConfig={'Delay': 5,}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c127d49",
   "metadata": {},
   "source": [
    "### boto3 で推論する\n",
    "[invoke_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime.html#SageMakerRuntime.Client.invoke_endpoint)で推論を実行できる。  \n",
    "client は `boto3.client('sagemaker')` ではなく、`boto3.client('sagemaker-runtime')`なことに注意。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a68d4a9",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "request_args = {\n",
    "    'EndpointName': endpoint_name,\n",
    "    'ContentType' : 'text/csv',\n",
    "    'Body' : pd.read_csv(f'{data_dir}/test/data.csv',header=None).iloc[0:1,1:].to_csv(header=False, index=False)\n",
    "}\n",
    "response = smr_client.invoke_endpoint(**request_args)\n",
    "np.argmax(json.loads(response['Body'].read())['probabilities'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb141c95",
   "metadata": {},
   "source": [
    "### boto3 でエンドポイント他を削除"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b61a459",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "r = sm_client.delete_endpoint(EndpointName=endpoint_name)\n",
    "r = sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n",
    "r = sm_client.delete_model(ModelName=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1737fcc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science 2.0)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}