{ "cells": [ { "cell_type": "markdown", "id": "64597428", "metadata": {}, "source": [ "# [SageMaker PyTorch Container](https://github.com/aws/deep-learning-containers/tree/master/pytorch/inference/docker) で推論チュートリアル\n" ] }, { "cell_type": "markdown", "id": "9e05a97a", "metadata": {}, "source": [ "## 準備" ] }, { "cell_type": "markdown", "id": "179d8f2d", "metadata": {}, "source": [ "### モジュールのインポート、定数の設定、boto3 クライアントの設定、ロールの取得" ] }, { "cell_type": "code", "execution_count": null, "id": "27e75a13", "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from typing import Final\n", "from sagemaker.pytorch import PyTorchModel\n", "from sagemaker.async_inference import AsyncInferenceConfig\n", "import os, boto3, json, numpy as np\n", "from io import BytesIO\n", "from time import sleep\n", "from uuid import uuid4\n", "smr_client:Final = boto3.client('sagemaker-runtime')\n", "sm_client:Final = boto3.client('sagemaker')\n", "s3_client:Final = boto3.client('s3')\n", "endpoint_inservice_waiter:Final = sm_client.get_waiter('endpoint_in_service')\n", "role: Final[str] = sagemaker.get_execution_role()\n", "region: Final[str] = sagemaker.Session().boto_region_name\n", "bucket: Final[str] = sagemaker.Session().default_bucket()" ] }, { "cell_type": "markdown", "id": "a59f7d0d", "metadata": {}, "source": [ "### モデルと推論コードを保存するディレクトリを作成" ] }, { "cell_type": "code", "execution_count": null, "id": "c6713f45", "metadata": {}, "outputs": [], "source": [ "model_dir: Final[str] = 'model'\n", "!if [ -d ./{model_dir} ]; then rm -rf ./{model_dir}/;fi\n", "!mkdir ./{model_dir}/\n", "\n", "source_dir: Final[str] = 'source'\n", "!if [ -d ./{source_dir} ]; then rm -rf ./{source_dir}/;fi\n", "!mkdir ./{source_dir}/" ] }, { "cell_type": "markdown", "id": "94ab3d8a", "metadata": {}, "source": [ "### モデル相当のテキストファイルを `tar.gz` で固めて S3 にアップロードする\n", "* SageMaker で推論する場合は機械学習のモデルを `model.tar.gz` に固めておく必要がある\n", " * SageMaker Training を使ってモデルを保存した場合は自動で tar.gz になるが、このチュートリアルでは Training Job を使わないため、手動で tar.gz に固める\n", " * 機械学習のモデルと言ったが、用意したファイルを読み込むコードを書き、その読み込んだデータを使って処理を行うだけなので必ずしも機械学習のモデルである必要はない\n", " * このチュートリアルでは Hello my great machine learning model と書かれたテキストファイル(`my_model.txt`)を作成して、`tar.gz` に固める\n", "* `model.tar.gz` を推論環境で使うには予め S3 にアップロードしておく必要がある" ] }, { "cell_type": "markdown", "id": "ddc97969", "metadata": {}, "source": [ "#### `my_model.txt` 作成" ] }, { "cell_type": "code", "execution_count": null, "id": "d9f2b3fb", "metadata": {}, "outputs": [], "source": [ "%%writefile ./{model_dir}/my_model.txt\n", "Hello my great machine learning model" ] }, { "cell_type": "markdown", "id": "93d9c52b", "metadata": {}, "source": [ "#### `my_model.txt` を `model.tar.gz` に固める" ] }, { "cell_type": "code", "execution_count": null, "id": "4eefcc94", "metadata": {}, "outputs": [], "source": [ "%cd {model_dir}\n", "!tar zcvf model.tar.gz ./*\n", "%cd .." 
] }, { "cell_type": "markdown", "id": "b6cda830", "metadata": {}, "source": [ "#### `model.tar.gz` を S3 にアップロード" ] }, { "cell_type": "code", "execution_count": null, "id": "634aeb6e", "metadata": {}, "outputs": [], "source": [ "model_s3_uri:Final[str] = sagemaker.session.Session().upload_data(\n", " f'./{model_dir}/model.tar.gz',\n", " key_prefix = 'hello_sagemaker_inference'\n", ")\n", "print(model_s3_uri)" ] }, { "cell_type": "markdown", "id": "4b327c0c", "metadata": {}, "source": [ "### 推論コードを作成する\n", "* 最低限 `model_fn` と `predict_fn` が必要\n", "* `model_fn` は `model.tar.gz` に固めたモデルを読み込むコード\n", " * 第一引数に `model.tar.gz` を展開したディレクトリが入る\n", " * 返り値にモデルを返すと、`predict_fn` の第二引数に入れられる\n", "* `predict_fn` は推論コード\n", " * 第一引数にリクエスト(推論したいデータ)が入る\n", " * 第二引数に model_fn の返り値が入る\n", " * 推論結果を返り値に入れる" ] }, { "cell_type": "code", "execution_count": null, "id": "c90fe886", "metadata": {}, "outputs": [], "source": [ "%%writefile ./{source_dir}/inference.py\n", "import os\n", "def model_fn(model_dir):\n", " with open(os.path.join(model_dir,'my_model.txt')) as f:\n", " model = f.read()[:-1] # 改行を除外\n", " return model\n", "def predict_fn(input_data, model):\n", " response = f'{model} for the {input_data}st time'\n", " return response" ] }, { "cell_type": "markdown", "id": "0b991cb5", "metadata": {}, "source": [ "## SageMaker SDK でモデルをデプロイしてリアルタイム推論\n", "### SageMaker SDK を用いてモデルをデプロイ\n", "SageMaker SDK でモデルをデプロイするには、[Model](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model) クラスでモデルを定義する必要がある \n", "今回は AWS が管理・公開している PyTorch のコンテナを使うため、`Model` を継承した [PyTorchModel](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/sagemaker.pytorch.html#sagemaker.pytorch.model.PyTorchModel) クラスを使用する。 \n", "`PyTorchModel` では、モデルにつける任意の名前、使用するモデルの S3 の URI、フレームワークや Python のバージョン、推論コードなどを指定する。\n", "PyTorchModel でインスタンスを生成したら、[deploy](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy) メソッドでモデルをデプロイできる。デプロイ時はインスタンスタイプと台数、エンドポイントにつける任意の名前を設定する。" ] }, { "cell_type": "code", "execution_count": null, "id": "84481baa", "metadata": {}, "outputs": [], "source": [ "# 名前の設定\n", "model_name: Final[str] = 'PyTorchModel'\n", "endpoint_name: Final[str] = model_name + 'Endpoint'" ] }, { "cell_type": "code", "execution_count": null, "id": "815c582a", "metadata": {}, "outputs": [], "source": [ "# モデルとコンテナの指定\n", "pytorch_model = PyTorchModel(\n", " name = model_name,\n", " model_data=model_s3_uri,\n", " role= role,\n", " framework_version = '1.11.0',\n", " py_version='py38',\n", " entry_point='inference.py',\n", " source_dir=f'./{source_dir}/'\n", ")\n", "pytorch_predictor = pytorch_model.deploy(\n", " initial_instance_count=1,\n", " instance_type='ml.m5.large',\n", " enpoint_name=endpoint_name\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "6f0dd52c", "metadata": {}, "outputs": [], "source": [ "response = pytorch_predictor.predict(1)\n", "print(response,type(response))" ] }, { "cell_type": "markdown", "id": "2b03fc98", "metadata": {}, "source": [ "リクエストはできたがレスポンスがなぜか numpy array である。 \n", "理由は [PyTorchPredictor](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/sagemaker.pytorch.html#sagemaker.pytorch.model.PyTorchPredictor) に serializer , desirializer がデフォルトで設定されており、`predict` メソッドでエンドポイントにリクエストする前にリクエストデータ(↑の例では int 型の1)を numpy array に Serialize してリクエストし、レスポンスを受け取った後にレスポンスデータを numpy array に Desiarlize するため。" ] }, { "cell_type": "code", "execution_count": null, "id": "a667c42a", "metadata": {}, "outputs": [], "source": 
[ "print(pytorch_predictor.serializer, pytorch_predictor.deserializer)" ] }, { "cell_type": "code", "execution_count": null, "id": "de0d3bf4", "metadata": {}, "outputs": [], "source": [ "# endpointとモデルを削除\n", "pytorch_predictor.delete_endpoint()\n", "pytorch_model.delete_model()" ] }, { "cell_type": "markdown", "id": "32d03d68", "metadata": {}, "source": [ "## Boto3 でリアルタイム推論\n", "serializer/desirializer は SageMaker SDK の機能で、推論エンドポイントに推論データをリクエストする環境(AWS Lambda など)には入っていないことが多い(boto3でやることが多い)。また、推論エンドポイント立ち上げもパイプラインに組み込む際は SageMaker SDKを使わない環境もありえる。 \n", "一連の流れを Boto3 で実行してみてSerializer/Deserializerが無い場合の挙動を確認する。" ] }, { "cell_type": "markdown", "id": "13730b45", "metadata": {}, "source": [ "### 推論コードを model.tar.gz に固めて S3 にアップロード\n", "* pytorch の場合は推論コードを model.tar.gz に内包する必要がある\n", "* SageMaker SDK では `deploy` メソッド実行時に裏側で推論コードを `model.tar.gz` に固めてアップロードしてくれていた\n", "* boto3 でモデルをデプロイする場合は手動で `tar.gz` で固めて S3 にアップロードする必要がある" ] }, { "cell_type": "code", "execution_count": null, "id": "b3d1aabe", "metadata": {}, "outputs": [], "source": [ "%cd {model_dir}\n", "!rm model.tar.gz\n", "!cp ../{source_dir}/inference.py ./\n", "!tar zcvf model.tar.gz ./*\n", "%cd .." ] }, { "cell_type": "code", "execution_count": null, "id": "abd8f173", "metadata": {}, "outputs": [], "source": [ "source_s3_uri:Final[str] = sagemaker.session.Session().upload_data(\n", " f'./{model_dir}/model.tar.gz',\n", " key_prefix = 'hello_sagemaker_inference'\n", ")\n", "print(source_s3_uri)" ] }, { "cell_type": "markdown", "id": "3055aaab", "metadata": {}, "source": [ "### EndpointConfigName 設定\n", "SageMaker SDK では `deploy` メソッド実行時に自動で Model と同じ名前で EndpointConfig を作成するが、Boto3 は明示的に作成する必要がある。" ] }, { "cell_type": "code", "execution_count": null, "id": "8720f1c8", "metadata": {}, "outputs": [], "source": [ "endpoint_config_name: Final[str] = model_name + 'EndpointConfig'" ] }, { "cell_type": "markdown", "id": "9c9fafe6", "metadata": {}, "source": [ "### モデル作成、エンドポイントコンフィグ作成、エンドポイント作成\n", "1. [create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) でモデルと推論環境(推論コードやコンテナイメージ、環境変数の設定)をパッケージ化した Model を作成する\n", "2. [create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) で使用する Model や推論に使うコンピューティングリソース(インスタンスタイプ、台数など)や負荷の配分を設定する\n", "3. 
{ "cell_type": "code", "execution_count": null, "id": "de0d3bf4", "metadata": {}, "outputs": [], "source": [ "# Delete the endpoint and the model\n", "pytorch_predictor.delete_endpoint()\n", "pytorch_model.delete_model()" ] },
{ "cell_type": "markdown", "id": "32d03d68", "metadata": {}, "source": [ "## Real-time inference with Boto3\n", "The serializer/deserializer is a SageMaker SDK feature, and the environment that sends inference requests to the endpoint (AWS Lambda, etc.) often does not have the SDK installed (boto3 is the usual choice there). There may also be environments where endpoint creation is built into a pipeline without the SageMaker SDK. \n", "Let's run the whole flow with Boto3 and see how things behave without a serializer/deserializer." ] },
{ "cell_type": "markdown", "id": "13730b45", "metadata": {}, "source": [ "### Pack the inference code into model.tar.gz and upload it to S3\n", "* With PyTorch, the inference code must be included inside `model.tar.gz`\n", "* The SageMaker SDK packed the inference code into `model.tar.gz` and uploaded it behind the scenes when `deploy` was called\n", "* When deploying a model with boto3, you must pack the `tar.gz` and upload it to S3 yourself" ] },
{ "cell_type": "code", "execution_count": null, "id": "b3d1aabe", "metadata": {}, "outputs": [], "source": [ "%cd {model_dir}\n", "!rm model.tar.gz\n", "!cp ../{source_dir}/inference.py ./\n", "!tar zcvf model.tar.gz ./*\n", "%cd .." ] },
{ "cell_type": "code", "execution_count": null, "id": "abd8f173", "metadata": {}, "outputs": [], "source": [ "source_s3_uri: Final[str] = sagemaker.session.Session().upload_data(\n", "    f'./{model_dir}/model.tar.gz',\n", "    key_prefix = 'hello_sagemaker_inference'\n", ")\n", "print(source_s3_uri)" ] },
{ "cell_type": "markdown", "id": "3055aaab", "metadata": {}, "source": [ "### Set the EndpointConfig name\n", "The SageMaker SDK automatically creates an EndpointConfig with the same name as the Model when `deploy` is called; with Boto3 it must be created explicitly." ] },
{ "cell_type": "code", "execution_count": null, "id": "8720f1c8", "metadata": {}, "outputs": [], "source": [ "endpoint_config_name: Final[str] = model_name + 'EndpointConfig'" ] },
{ "cell_type": "markdown", "id": "9c9fafe6", "metadata": {}, "source": [ "### Create the Model, the EndpointConfig, and the Endpoint\n", "1. [create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) creates a Model that packages the model with its inference environment (inference code, container image, and environment variables)\n", "2. [create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) configures which Model to use, the compute resources for inference (instance type, count, etc.), and how traffic is distributed\n", "3. [create_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint) deploys what the EndpointConfig describes" ] },
{ "cell_type": "code", "execution_count": null, "id": "464fe196", "metadata": {}, "outputs": [], "source": [ "# Get the container image URI\n", "container_image_uri: Final[str] = sagemaker.image_uris.retrieve(\n", "    \"pytorch\", \n", "    sagemaker.session.Session().boto_region_name, # region of the ECR repository\n", "    version='1.11.0', # PyTorch version\n", "    instance_type = 'ml.m5.large', # instance type\n", "    image_scope = 'inference' # select the inference container\n", ")\n", "print(container_image_uri)" ] },
{ "cell_type": "markdown", "id": "7a282306", "metadata": {}, "source": [ "create_endpoint is an asynchronous API: it returns immediately while the endpoint is created in the background. [endpoint_inservice_waiter.wait](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.get_waiter) waits until endpoint creation completes (this takes a few minutes)." ] },
{ "cell_type": "code", "execution_count": null, "id": "9abdcb19", "metadata": {}, "outputs": [], "source": [ "# Create the Model\n", "response = sm_client.create_model(\n", "    ModelName=model_name,\n", "    PrimaryContainer={\n", "        'Image': container_image_uri,\n", "        'ModelDataUrl': model_s3_uri,\n", "        'Environment': {\n", "            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", "            'SAGEMAKER_PROGRAM': 'inference.py',\n", "            'SAGEMAKER_REGION': region,\n", "            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code'}\n", "    },\n", "    ExecutionRoleArn=role,\n", ")\n", "# Create the EndpointConfig\n", "response = sm_client.create_endpoint_config(\n", "    EndpointConfigName=endpoint_config_name,\n", "    ProductionVariants=[\n", "        {\n", "            'VariantName': 'AllTraffic',\n", "            'ModelName': model_name,\n", "            'InitialInstanceCount': 1,\n", "            'InstanceType': 'ml.m5.large',\n", "        },\n", "    ],\n", ")\n", "# Create the Endpoint\n", "response = sm_client.create_endpoint(\n", "    EndpointName=endpoint_name,\n", "    EndpointConfigName=endpoint_config_name,\n", ")\n", "# Wait until the Endpoint is in service\n", "endpoint_inservice_waiter.wait(\n", "    EndpointName=endpoint_name,\n", "    WaiterConfig={'Delay': 5,}\n", ")" ] },
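{ "cell_type": "markdown", "id": "e4b3c7f1", "metadata": {}, "source": [ "Once the waiter returns, the endpoint status can be double-checked with [describe_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint) (an optional check; it should print `InService`):" ] },
{ "cell_type": "code", "execution_count": null, "id": "e4b3c7f2", "metadata": {}, "outputs": [], "source": [ "# Optional check: the endpoint should now be 'InService'\n", "print(sm_client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus'])" ] },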
\n", "[default_input_fn](https://github.com/aws/sagemaker-scikit-learn-container/blob/7773e19bf0df6bdd65f10076ff7e8ecc1390cb9b/src/sagemaker_sklearn_container/handler_service.py#L47) が影響している" ] }, { "cell_type": "code", "execution_count": null, "id": "20f2853b", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# default_input_fn 確認\n", "!curl -s https://raw.githubusercontent.com/aws/deep-learning-containers/master/pytorch/inference/docker/build_artifacts/default_inference_handler.py | pygmentize" ] }, { "cell_type": "code", "execution_count": null, "id": "9dd0f315", "metadata": {}, "outputs": [], "source": [ "# default_input_fn が呼ぶ decode の確認\n", "!curl -s https://raw.githubusercontent.com/aws/sagemaker-inference-toolkit/master/src/sagemaker_inference/decoder.py | pygmentize" ] }, { "cell_type": "code", "execution_count": null, "id": "5fced00c", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# content_types の確認\n", "!curl -s https://raw.githubusercontent.com/aws/sagemaker-inference-toolkit/master/src/sagemaker_inference/content_types.py | pygmentize" ] }, { "cell_type": "markdown", "id": "f83815b5", "metadata": {}, "source": [ "request header が `application/json` だったら `torch.FloatTensor(np.array(json.loads(input_data), dtype=None))` していたため、小数の torch tensor になっていた。 \n", "上記を踏まえ、SageMaker SDK の [NumpySerializer](https://github.com/aws/sagemaker-python-sdk/blob/bd8ea409ae91b07ac148520f7631fba9feee0069/src/sagemaker/serializers.py#L148) の [_serialize_array()](https://github.com/aws/sagemaker-python-sdk/blob/bd8ea409ae91b07ac148520f7631fba9feee0069/src/sagemaker/serializers.py#L188) 相当を手動で行い、ヘッダ `application/x-npy` でリクエストしてみると、torch tensor にならずに済む" ] }, { "cell_type": "code", "execution_count": null, "id": "c0c2cf7d", "metadata": {}, "outputs": [], "source": [ "buffer = BytesIO()\n", "np.save(buffer,np.array(1))\n", "response = smr_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType='application/x-npy',\n", " Accept='application/json',\n", " Body=buffer.getvalue(),\n", ")\n", "predictions = json.loads(response['Body'].read().decode('utf-8'))\n", "print(predictions)" ] }, { "cell_type": "markdown", "id": "d95d076b", "metadata": {}, "source": [ "### Model, EndpointConfig, Endpoint を削除" ] }, { "cell_type": "code", "execution_count": null, "id": "35234916", "metadata": {}, "outputs": [], "source": [ "sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "sm_client.delete_model(ModelName=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "e40b3cd3", "metadata": {}, "outputs": [], "source": [ "# 削除が完了するまで待つ\n", "sleep(5)" ] }, { "cell_type": "markdown", "id": "4726507c", "metadata": {}, "source": [ "## 前処理と後処理の追加\n", "### 前処理と後処理の関数作成\n", "推論コードに `input_fn` と `output_fn` を記述すると、`default_input_fn` や `default_output_fn` は使われずにユーザーが作成した input_fn や output_fn が使われるようになる" ] }, { "cell_type": "code", "execution_count": null, "id": "a838a5e4", "metadata": {}, "outputs": [], "source": [ "%%writefile ./{source_dir}/inference.py\n", "import os, json\n", "def model_fn(model_dir):\n", " with open(os.path.join(model_dir,'my_model.txt')) as f:\n", " hello = f.read()[:-1] # 改行を除外\n", " return hello\n", "def input_fn(input_data, content_type):\n", " if content_type == 'text/csv':\n", " transformed_data = input_data.split(',')\n", " else:\n", " raise ValueError(f\"Illegal content type {content_type}. 
{ "cell_type": "code", "execution_count": null, "id": "a838a5e4", "metadata": {}, "outputs": [], "source": [ "%%writefile ./{source_dir}/inference.py\n", "import os\n", "def model_fn(model_dir):\n", "    with open(os.path.join(model_dir,'my_model.txt')) as f:\n", "        hello = f.read()[:-1] # strip the trailing newline\n", "    return hello\n", "def input_fn(input_data, content_type):\n", "    if content_type == 'text/csv':\n", "        transformed_data = input_data.split(',')\n", "    else:\n", "        raise ValueError(f\"Illegal content type {content_type}. The only allowed content_type is text/csv\")\n", "    print(input_data,transformed_data)\n", "    return transformed_data\n", "def predict_fn(transformed_data, model):\n", "    prediction_list = []\n", "    for data in transformed_data:\n", "        if data[-1] == '1':\n", "            ordinal = f'{data}st'\n", "        elif data[-1] == '2':\n", "            ordinal = f'{data}nd'\n", "        elif data[-1] == '3':\n", "            ordinal = f'{data}rd'\n", "        else:\n", "            ordinal = f'{data}th'\n", "        prediction = f'{model} for the {ordinal} time'\n", "        prediction_list.append(prediction)\n", "    print(transformed_data,prediction_list)\n", "    return prediction_list\n", "def output_fn(prediction_list, accept):\n", "    if accept == 'text/csv':\n", "        response = ''\n", "        for prediction in prediction_list:\n", "            response += prediction + '\\n'\n", "        print(prediction_list,response)\n", "    else:\n", "        raise ValueError(f\"Illegal accept type {accept}. The only allowed accept type is text/csv\")\n", "    return response, accept" ] },
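{ "cell_type": "markdown", "id": "f5c4d6a1", "metadata": {}, "source": [ "As before, the whole `input_fn` -> `predict_fn` -> `output_fn` chain can be exercised locally before packing it (a notebook-only sketch; on the endpoint, the inference toolkit drives these calls):" ] },
{ "cell_type": "code", "execution_count": null, "id": "f5c4d6a2", "metadata": {}, "outputs": [], "source": [ "# Reload the updated handler and run the chain locally (sketch)\n", "import importlib.util\n", "spec = importlib.util.spec_from_file_location('inference', f'./{source_dir}/inference.py')\n", "inference = importlib.util.module_from_spec(spec)\n", "spec.loader.exec_module(inference)\n", "local_model = inference.model_fn(model_dir)\n", "body, accept = inference.output_fn(inference.predict_fn(inference.input_fn('1,2,3,10000', 'text/csv'), local_model), 'text/csv')\n", "print(body)" ] },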
"outputs": [], "source": [ "sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "sm_client.delete_model(ModelName=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "043c9ef2", "metadata": {}, "outputs": [], "source": [ "sleep(5)" ] }, { "cell_type": "markdown", "id": "4dce147d", "metadata": {}, "source": [ "## 非同期推論\n", "* 非同期推論は推論データを S3 に配置し、推論するときは S3 のどこに推論データがあるのかを引数に入れる\n", "* 推論結果はレスポンスにある S3 の URI に格納されるが、レスポンスされたタイミングでは推論結果が置かれている保証はなく、処理が終わり次第配置される\n", "* [create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) の AsyncInferenceConfig 引数を設定することで非同期推論エンドポイントが立ち上がる\n", "\n", "### エンドポイント作成" ] }, { "cell_type": "code", "execution_count": null, "id": "c83102c4", "metadata": {}, "outputs": [], "source": [ "# Model 作成\n", "response = sm_client.create_model(\n", " ModelName=model_name,\n", " PrimaryContainer={\n", " 'Image': container_image_uri,\n", " 'ModelDataUrl': model_s3_uri,\n", " 'Environment': {\n", " 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", " 'SAGEMAKER_PROGRAM': 'inference.py',\n", " 'SAGEMAKER_REGION': region,\n", " 'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code'}\n", " },\n", " ExecutionRoleArn=role,\n", ")\n", "# EndpointConfig 作成\n", "response = sm_client.create_endpoint_config(\n", " EndpointConfigName=endpoint_config_name,\n", " ProductionVariants=[\n", " {\n", " 'VariantName': 'AllTrafic',\n", " 'ModelName': model_name,\n", " 'InitialInstanceCount': 1,\n", " 'InstanceType': 'ml.m5.large',\n", " },\n", " ],\n", " AsyncInferenceConfig={\n", " \"OutputConfig\": {\n", " \"S3OutputPath\": f\"s3://{bucket}/hello_sagemaker_inference/async_inference/output\"\n", " },\n", " }\n", ")\n", "# Endpoint 作成\n", "response = sm_client.create_endpoint(\n", " EndpointName=endpoint_name,\n", " EndpointConfigName=endpoint_config_name,\n", ")\n", "# Endpoint が有効化されるまで待つ\n", "endpoint_inservice_waiter.wait(\n", " EndpointName=endpoint_name,\n", " WaiterConfig={'Delay': 5,}\n", ")\n" ] }, { "cell_type": "markdown", "id": "5317c598", "metadata": {}, "source": [ "### 推論データ作成" ] }, { "cell_type": "code", "execution_count": null, "id": "f03e2f08", "metadata": {}, "outputs": [], "source": [ "input_data: Final[str] = './input_data.csv'\n", "with open(input_data,'wt') as f:\n", " f.write('2,3,4,1000')" ] }, { "cell_type": "markdown", "id": "35884254", "metadata": {}, "source": [ "### 推論データを S3 にアップロード" ] }, { "cell_type": "code", "execution_count": null, "id": "02da2036", "metadata": {}, "outputs": [], "source": [ "input_data_s3_uri:Final[str] = sagemaker.Session().upload_data(\n", " './input_data.csv',\n", " key_prefix = 'hello_sagemaker_inference/async_inference'\n", ")\n", "print(input_data_s3_uri)" ] }, { "cell_type": "markdown", "id": "5baaa9b1", "metadata": {}, "source": [ "### 推論と推論結果を取得\n", "推論には [invoke_endpoint_async](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime.html#SageMakerRuntime.Client.invoke_endpoint_async) を使う" ] }, { "cell_type": "code", "execution_count": null, "id": "2a753e92", "metadata": {}, "outputs": [], "source": [ "response = smr_client.invoke_endpoint_async(\n", " EndpointName=endpoint_name, \n", " InputLocation=input_data_s3_uri,\n", " ContentType='text/csv',\n", " Accept='text/csv',\n", ")\n", "output_s3_uri = response['OutputLocation']\n", "output_key = output_s3_uri.replace(f's3://{bucket}/','')\n", "while True:\n", 
" result = s3_client.list_objects(Bucket=bucket, Prefix=output_key)\n", " exists = True if \"Contents\" in result else False\n", " if exists:\n", " print('!')\n", " obj = s3_client.get_object(Bucket=bucket, Key=output_key)\n", " predictions = obj['Body'].read().decode()\n", " print(predictions)\n", " break\n", " else:\n", " print('.',end='')\n", " sleep(0.1)" ] }, { "cell_type": "markdown", "id": "5cdf50f5", "metadata": {}, "source": [ "### リソース削除" ] }, { "cell_type": "code", "execution_count": null, "id": "0689de26", "metadata": {}, "outputs": [], "source": [ "sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "sm_client.delete_model(ModelName=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "5d22db91", "metadata": {}, "outputs": [], "source": [ "sleep(5)" ] }, { "cell_type": "markdown", "id": "a599e237", "metadata": {}, "source": [ "## サーバーレス推論\n", "* サーバーレス推論は、コンピューティングリソースをプロビジョンせず、推論が発生している時間に対して課金する推論方法\n", "* サーバーレス推論の推論エンドポイントの立ち上げ方は [create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) の Variant 内の ServerlessConfig で設定する" ] }, { "cell_type": "code", "execution_count": null, "id": "f5d33b72", "metadata": {}, "outputs": [], "source": [ "# Model 作成\n", "response = sm_client.create_model(\n", " ModelName=model_name,\n", " PrimaryContainer={\n", " 'Image': container_image_uri,\n", " 'ModelDataUrl': model_s3_uri,\n", " 'Environment': {\n", " 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", " 'SAGEMAKER_PROGRAM': 'inference.py',\n", " 'SAGEMAKER_REGION': region,\n", " 'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code'}\n", " },\n", " ExecutionRoleArn=role,\n", ")\n", "# EndpointConfig 作成\n", "response = sm_client.create_endpoint_config(\n", " EndpointConfigName=endpoint_config_name,\n", " ProductionVariants=[\n", " {\n", " 'ModelName': model_name,\n", " 'VariantName': 'AllTrafic',\n", " 'ServerlessConfig': { \n", " 'MemorySizeInMB': 1024, \n", " 'MaxConcurrency': 3\n", " }\n", " },\n", " ],\n", ")\n", "# Endpoint 作成\n", "response = sm_client.create_endpoint(\n", " EndpointName=endpoint_name,\n", " EndpointConfigName=endpoint_config_name,\n", ")\n", "# Endpoint が有効化されるまで待つ\n", "endpoint_inservice_waiter.wait(\n", " EndpointName=endpoint_name,\n", " WaiterConfig={'Delay': 5,}\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "ab91a722", "metadata": {}, "outputs": [], "source": [ "response = smr_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType='text/csv',\n", " Accept='text/csv',\n", " Body='1,2,3,10000'\n", ")\n", "predictions = response['Body'].read().decode('utf-8')\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "id": "d9635941", "metadata": {}, "outputs": [], "source": [ "sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "sm_client.delete_model(ModelName=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "c44605ac", "metadata": {}, "outputs": [], "source": [ "sleep(5)" ] }, { "cell_type": "markdown", "id": "871ba936", "metadata": {}, "source": [ "## バッチ推論\n", "* バッチ推論は溜まったデータをまとめて推論するコスト効率が良い方法で、レイテンシーを求められない時に使用する\n", "* バッチ推論は [create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) でモデルを作成したあと、create_transform_job 
{ "cell_type": "markdown", "id": "207d4b46", "metadata": {}, "source": [ "### Create multiple inference data files" ] },
{ "cell_type": "code", "execution_count": null, "id": "2bab7514", "metadata": { "scrolled": true }, "outputs": [], "source": [ "batch_data_dir: Final[str] = './batch'\n", "!mkdir -p {batch_data_dir}\n", "input_data1: Final[str] = 'input_data1.csv'\n", "input_data2: Final[str] = 'input_data2.csv'\n", "with open(os.path.join(batch_data_dir,input_data1),'wt') as f:\n", "    f.write('3,4,5,100')\n", "with open(os.path.join(batch_data_dir,input_data2),'wt') as f:\n", "    f.write('9,8,7,6,5')\n", "# prefix must be defined before the cleanup command below uses it\n", "prefix: Final[str] = 'hello_sagemaker_inference/transform_job'\n", "input_prefix: Final[str] = prefix + '/input'\n", "output_prefix: Final[str] = prefix + '/output'\n", "!aws s3 rm --recursive s3://{bucket}/{prefix}\n", "input_data_s3_uri: Final[str] = sagemaker.Session().upload_data(batch_data_dir,key_prefix = input_prefix)\n", "print(input_data_s3_uri)" ] },
{ "cell_type": "markdown", "id": "7d16e668", "metadata": {}, "source": [ "### Create the model" ] },
{ "cell_type": "code", "execution_count": null, "id": "3af27428", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Create the Model\n", "response = sm_client.create_model(\n", "    ModelName=model_name,\n", "    PrimaryContainer={\n", "        'Image': container_image_uri,\n", "        'ModelDataUrl': model_s3_uri,\n", "        'Environment': {\n", "            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", "            'SAGEMAKER_PROGRAM': 'inference.py',\n", "            'SAGEMAKER_REGION': region,\n", "            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code'}\n", "    },\n", "    ExecutionRoleArn=role,\n", ")" ] },
{ "cell_type": "markdown", "id": "d9dbca45", "metadata": {}, "source": [ "### Create the transform job" ] },
{ "cell_type": "code", "execution_count": null, "id": "af54cf94", "metadata": {}, "outputs": [], "source": [ "transform_job_name: Final[str] = f'{model_name}TransformJob-{uuid4()}'\n", "print(transform_job_name)\n", "response = sm_client.create_transform_job(\n", "    TransformJobName=transform_job_name,\n", "    ModelName=model_name,\n", "    TransformInput={\n", "        'DataSource': {\n", "            'S3DataSource': {\n", "                'S3DataType': 'S3Prefix',\n", "                'S3Uri': f's3://{bucket}/{input_prefix}'\n", "            }\n", "        },\n", "        'ContentType': 'text/csv',\n", "    },\n", "    TransformOutput={\n", "        'S3OutputPath': f's3://{bucket}/{output_prefix}',\n", "        'Accept': 'text/csv',\n", "    },\n", "    TransformResources={\n", "        'InstanceType': 'ml.m5.large',\n", "        'InstanceCount': 1,\n", "    }\n", ")" ] },
{ "cell_type": "markdown", "id": "22dd545a", "metadata": {}, "source": [ "### Fetch the inference results" ] },
{ "cell_type": "code", "execution_count": null, "id": "801c6b3e", "metadata": {}, "outputs": [], "source": [ "while True:\n", "    if sm_client.describe_transform_job(TransformJobName=transform_job_name)['TransformJobStatus'] == 'Completed':\n", "        print('!')\n", "        for content in s3_client.list_objects_v2(Bucket=bucket,Prefix=output_prefix)['Contents']:\n", "            obj = s3_client.get_object(Bucket=bucket, Key=content['Key'])\n", "            predictions = obj['Body'].read().decode()\n", "            print(predictions)\n", "        break\n", "    else:\n", "        print('.',end='')\n", "        sleep(5)" ] },
{ "cell_type": "markdown", "id": "88052c48", "metadata": {}, "source": [ "## Delete everything" ] },
{ "cell_type": "code", "execution_count": null, "id": "0e70e960", "metadata": {}, "outputs": [], "source": [ "# for endpoint in sm_client.list_endpoints()['Endpoints']:\n", "#     response = sm_client.delete_endpoint(EndpointName=endpoint['EndpointName'])\n", "#     print(response)" ] },
{ "cell_type": "code", "execution_count": 
null, "id": "c11e076e", "metadata": {}, "outputs": [], "source": [ "# for endpoint_config in sm_client.list_endpoint_configs()['EndpointConfigs']:\n", "# response = sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config['EndpointConfigName'])\n", "# print(response)" ] }, { "cell_type": "code", "execution_count": null, "id": "fdf14cf3", "metadata": {}, "outputs": [], "source": [ "# for model in sm_client.list_models()['Models']:\n", "# response = sm_client.delete_model(ModelName=model['ModelName'])\n", "# print(response)" ] }, { "cell_type": "code", "execution_count": null, "id": "a55b0e24", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }