{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Amazon SageMaker Model Monitor と Debugger を使って不正な予測を検知して分析する\n", "\n", "### 依存ライブラリのインストール\n", "\n", "下のセルを実行して依存ライブラリをインストールしてカーネルを再起動してください" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "! pip install imageio opencv-python 'sagemaker>=2,<3' smdebug torchvision\n", "! git clone https://github.com/advboxes/AdvBox advbox\n", "! cd advbox; python setup.py build; python setup.py install\n", "! pip install future" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "モジュールをImportできるかテストします。\n", "\n", "エラーが発生した場合はカーネルを再起動してみてください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import advbox" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### モデルをアーカイブしてAmazon S3にアップロードします\n", "\n", "このノートブックは、43のクラスで構成される [German Traffic Sign dataset](http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset)で学習されたResNet18モデルを使用します。 モデルをSageMakerにデプロイする前に、その重みをアーカイブしてAmazonS3にアップロードする必要があります。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir model\n", "!wget https://github.com/aws-samples/amazon-sagemaker-analyze-model-predictions/raw/master/model/model.pt -O model/model.pt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tarfile\n", "\n", "with tarfile.open('model.tar.gz', mode='w:gz') as archive:\n", " archive.add('model', recursive=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "\n", "boto_session = boto3.Session()\n", "sagemaker_session = sagemaker.Session(boto_session=boto_session)\n", "\n", "inputs = sagemaker_session.upload_data(path='model.tar.gz', key_prefix='model')\n", "print(inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorchモデルとデプロイスクリプトを定義\n", "\n", "[SageMaker hosting services](https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/how-it-works-deployment.html)を使用して、モデルから予測を取得するための永続的なエンドポイントを設定します。 このためにmodel_data引数にモデルをアーカイブしたS3 PathをとるPyTorchModelオブジェクトを定義します。\n", "entry_pointには、model_fn関数とtransform_fn関数が含まれるpretrained_model.pyを定義します。これらの関数はホスティング中に使用され、モデルが推論コンテナ内で正しく読み込まれ、リクエストを適切に処理できるようにします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "\n", "role = sagemaker.get_execution_role()\n", "\n", "sagemaker_model = PyTorchModel(\n", " model_data=f's3://{sagemaker_session.default_bucket()}/model/model.tar.gz',\n", " role=role,\n", " source_dir='code',\n", " entry_point='pretrained_model.py',\n", " framework_version='1.3.1',\n", " py_version='py3',\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`pretrained_model.py`を見てみましょう" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pycat code/pretrained_model.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SageMaker Model Monitorのセットアップとモデルのデプロイ\n", "\n", "[SageMaker Model Monitor](https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/model-monitor.html)は、本番環境の機械学習モデルを自動的に監視し、データ品質の問題を検出するとアラートを出します。 エンドポイントの入力と出力をキャプチャし、後でModel Monitorが収集したデータとモデルの予測を検査できるように監視スケジュールを作成します。\n", "\n", "DataCaptureConfig APIは、Model Monitorが出力先のAmazon S3バケットに保存する入力と出力の割合を指定します。 この例では、サンプリングの割合が50%に設定されています。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.model_monitor import DataCaptureConfig\n", "\n", "data_capture_config = DataCaptureConfig(\n", " enable_capture=True,\n", " sampling_percentage=50,\n", " destination_s3_uri=f's3://{sagemaker_session.default_bucket()}/endpoint/data_capture',\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これで、エンドポイントを`ml.m5.xlarge`インスタンスにデプロイする準備が整いました。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor = sagemaker_model.deploy(\n", " initial_instance_count=1,\n", " instance_type='ml.m5.xlarge',\n", " data_capture_config=data_capture_config,\n", " # エンドポイントは、デフォルトで期待されるnumpyではなくJSONを返します\n", " deserializer=sagemaker.deserializers.JSONDeserializer(),\n", ")\n", "\n", "endpoint_name = predictor.endpoint_name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 正常のテストデータで推論を実行する\n", "\n", "推論を実行する前に、[German Traffic Sign dataset](http://benchmark.ini.rub.de/?section=gtsrb&subsection=news) のテスト画像をダウンロードします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! wget -N https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip\n", "! unzip -oq GTSRB_Final_Test_Images.zip\n", "! wget -N https://raw.githubusercontent.com/aditbiswas1/P2-traffic-sign-classifier/master/signnames.csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "画像クラスの名前をロードします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pandas.io.parsers import read_csv\n", "\n", "signnames = read_csv('signnames.csv').values[:, 1]\n", "signnames" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "次に、シリアル化された入力画像を含むペイロードを使用してエンドポイントを呼び出します。 エンドポイントは、transform_fn関数を呼び出して、推論を実行する前にデータを前処理します。エンドポイントはjson文字列にエンコードされた整数のlistとして、画像ストリームの予測クラスを返します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "from PIL import Image\n", "import glob\n", "import matplotlib.pyplot as plt\n", "\n", "ncols = 4 # Plot in multiple columns to save some space\n", "\n", "for index, file in enumerate(glob.glob('GTSRB/Final_Test/Images/*ppm')):\n", " if index >= 20:\n", " break\n", "\n", " # Load image file to array:\n", " image = Image.open(file)\n", "\n", " # Invoke the endpoint: (Returns a 2D 1x1 array)\n", " result = predictor.predict(image)\n", "\n", " # Plot the results:\n", " ixcol = index % ncols\n", " if (ixcol == 0):\n", " plt.show()\n", " fig = plt.figure(figsize=(ncols*4.5, 4.5))\n", " plt.subplot(1, ncols, ixcol + 1)\n", " plt.title(signnames[result[0][0]])\n", " plt.imshow(image)\n", " plt.axis('off')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SageMaker Model Monitorのスケジュールを定義する\n", "\n", "次に、SageMaker Model Monitorを使用して、ベースラインを設定し、監視スケジュールを設定する方法について説明します。 \n", "Model Monitorは、制約条件の設定と平均、分位数、標準偏差のような統計量を計算する[baseline](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-create-baseline.html)という組み込みコンテナを提供しています。[monitoring schedule](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-scheduling.html)を起動すると、収集されたデータを検査して、指定された制約条件と比較し、違反している場合にレポートを生成するProcessingジョブを定期的に実行できます。\n", "\n", "この例では、単純なモデルのサニティチェックのみを実行するカスタムコンテナを作成します。[evaluationスクリプト](./docker/evaluation.py)は、予測された画像クラスを単純にカウントします。モデルが特定の道路標識を他のクラスよりも頻繁に予測するような場合、または信頼スコアが一貫して低い場合に問題が起きていると捉えます。ここでは特定の画像クラスが50%以上の確率で予測される場合にアラートをあげるようにします。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "まず、カスタムコンテナを作成する必要があります。 [dockerfile](./docker/Dockerfile)は、 evaluationスクリプトをエントリポイントファイルとして受け取ります。次のコードセルは、Dockerコンテナをビルドし、Amazon ECRにアップロードします。\n", "\n", "**このノートブックをSageMaker Studio内で実行する場合、`docker build`は機能しないため、`docker`コマンドを提供するインスタンスでコマンドを実行する必要があります。**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "ecr_repository = 'sagemaker-processing-container'\n", "tag = ':latest'\n", "\n", "region = boto3.session.Session().region_name\n", "\n", "uri_suffix = 'amazonaws.com'\n", "if region in ['cn-north-1', 'cn-northwest-1']:\n", " uri_suffix = 'amazonaws.com.cn'\n", "processing_repository_uri = f'{account_id}.dkr.ecr.{region}.{uri_suffix}/{ecr_repository + tag}'\n", "\n", "# Create ECR repository and push docker image\n", "!docker build -t $ecr_repository docker\n", "!$(aws ecr get-login --region $region --registry-ids $account_id --no-include-email)\n", "!aws ecr create-repository --repository-name $ecr_repository\n", "!docker tag {ecr_repository + tag} $processing_repository_uri\n", "!docker push $processing_repository_uri" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SageMakerモデルモニターを定義します。DockerイメージのURIとevaluationスクリプトに必要な環境変数を指定します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.model_monitor import ModelMonitor\n", "\n", "monitor = ModelMonitor(\n", " role=role,\n", " image_uri=processing_repository_uri,\n", " instance_count=1,\n", " instance_type='ml.m5.large',\n", " env={ 'THRESHOLD':'0.5' },\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "次に、[Model Monitor Schedule](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-scheduling.html)を定義してエンドポイントにアタッチします。 このカスタムコンテナは1時間ごとに実行されます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.model_monitor import CronExpressionGenerator, MonitoringOutput\n", "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", "\n", "destination = f's3://{sagemaker_session.default_bucket()}/endpoint/monitoring_schedule'\n", "processing_output = ProcessingOutput(output_name='model_outputs',\n", " source='/opt/ml/processing/outputs',\n", " destination=destination,\n", " )\n", "output = MonitoringOutput(source=processing_output.source, \n", " destination=processing_output.destination)\n", "\n", "monitor.create_monitoring_schedule(\n", " output=output,\n", " endpoint_input=predictor.endpoint_name,\n", " schedule_cron_expression=CronExpressionGenerator.hourly(), # 1hごとに実行\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "監視スケジュールを見てみましょう" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "monitor.describe_schedule()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SageMaker Model Monitorは、1時間ごとにProcessingジョブを実行します。 これらのProcessingジョブの実行を一覧表示できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jobs = monitor.list_executions()\n", "jobs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Processingジョブの詳細にアクセスできます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if len(jobs) > 0:\n", " print(monitor.list_executions()[-1].describe())\n", "else:\n", " print(\"\"\"No processing job has been executed yet. \n", " This means that one hour has not passed yet. \n", " You can go to the next code cell and run the processing job manually\"\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Processingジョブの実行\n", "\n", "1時間待つ代わりに、手動でProcessingジョブを開始して、いくつかの分析結果を取得できます。 これを行うために、カスタム画像の画像URIを取得するProcessorオブジェクトを定義します。 ジョブの入力は、キャプチャされた推論リクエストとレスポンスが保存されるS3のPathになり、スケジュールされたジョブが書き込むのと同じ宛先に結果を出力します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.processing import Processor\n", "\n", "processor = Processor(\n", " role=role,\n", " image_uri=processing_repository_uri,\n", " instance_count=1,\n", " instance_type='ml.m5.large',\n", " env={ 'THRESHOLD':'0.5' },\n", ")\n", " \n", "processor.run(\n", " [ProcessingInput(input_name='data',\n", " source=f's3://{sagemaker_session.default_bucket()}/endpoint/data_capture',\n", " destination='/opt/ml/processing/input/endpoint/',\n", " )],\n", " [ProcessingOutput(output_name='model_outputs',\n", " source='/opt/ml/processing/outputs',\n", " destination=destination,\n", " )],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 予期しないモデルの動作をキャプチャする" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! mkdir adversarial_examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "スケジュールが定義されたので、モデルをリアルタイムで監視する準備が整いました。[AdvBox Toolkit](https://github.com/advboxes/AdvBox)を使用して、データセットから「敵対的」に変更された画像を生成します。この場合、ピクセルレベルの摂動により、モデルがだまされて誤ったクラスが予測されます。 画像は元の画像と視覚的に似ています。 この敵対的なデータをモデルで実行し、監視設定で予期しない動作を識別できることを確認します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import utils\n", "from advbox.adversarialbox.adversary import Adversary\n", "from advbox.adversarialbox.attacks.deepfool import DeepFoolAttack\n", "from advbox.adversarialbox.models.pytorch import PytorchModel\n", "\n", "model = utils.load_model()\n", "\n", "m = PytorchModel(model, None, (-3, 3), channel_axis=1)\n", "\n", "attack = DeepFoolAttack(m)\n", "attack_config = { 'iterations': 100, 'overshoot': 0.02 }\n", "\n", "dataloader = utils.get_dataloader() # GTSRB/Final_Testからdatasetとdataloaderを作成" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "敵対的な画像を生成します\n", "\n", "次のセルはml.t2.mediumの場合、完了まで15minほどかかります" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "import numpy as np\n", "\n", "\n", "for index, (inputs, labels) in enumerate(dataloader):\n", "\n", " adversary = Adversary(inputs.cpu().numpy(), None)\n", "\n", " tlabel = 14 #class label for stop sign\n", " adversary.set_target(is_targeted_attack=True, target_label=tlabel)\n", "\n", " adversary = attack(adversary, **attack_config)\n", "\n", " if adversary.is_successful():\n", "\n", " adv=adversary.adversarial_example[0]\n", "\n", " utils.show_images_diff(inputs, adv, adversary.adversarial_label, signnames, index) # オリジナル、生成画像、差分の表示と生成画像の保存\n", "\n", " if index == 100:\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "それでは敵対的な画像をエンドポイントに送信してみましょう" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "ncols = 4\n", "\n", "for index, file in enumerate(glob.glob('adversarial_examples/*png')):\n", " # Load image file to array:\n", " image = Image.open(file)\n", "\n", " # Invoke the endpoint: (Returns a 2D 1x1 array)\n", " result = predictor.predict(image)\n", "\n", " # Plot the results:\n", " ixcol = index % ncols\n", " if (ixcol == 0):\n", " plt.show()\n", " fig = plt.figure(figsize=(ncols*4.5, 4.5))\n", " plt.subplot(1, ncols, ixcol + 1)\n", " plt.title(signnames[result[0][0]])\n", " plt.imshow(image)\n", " plt.axis('off')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Processing jobを実行する\n", "\n", "SageMaker Model Monitorは1時間ごとにProcessingジョブを実行しますが、以前と同様に手動でProcessingジョブを実行することもできます。\n", "\n", "これは敵対的な画像の送信をModel Monitorの分析に反映させるためのものです" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "processor.run(\n", " [ProcessingInput(input_name='data',\n", " source=f's3://{sagemaker_session.default_bucket()}/endpoint/data_capture',\n", " destination='/opt/ml/processing/input/endpoint/',\n", " )],\n", " [ProcessingOutput(output_name='model_outputs',\n", " source='/opt/ml/processing/outputs',\n", " destination=destination,\n", " )],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "SageMaker Model Monitorは、次のProcessingジョブをスケジュールするときに、Amazon S3でキャプチャおよび保存された予測結果を分析します。 前述のように、Processingジョブは単純なサニティチェックを実行するだけです。予測された画像クラスをカウントするだけで、1つのクラスが50%以上の確率で予測されると、アラートを発生します。advserialイメージをエンドポイントに送信したため、イメージクラス14(「一時停止の標識」)の異常なカウントが表示されます。 SageMaker Studioでジョブのステータスを追跡すると、問題が見つかったことがわかります。\n", "\n", "![](images/screenshot.png)\n", "\n", "CloudWatchログからさらに詳細を取得できます(/aws/sagemaker/ProcessingJobs内にあります)。Processingジョブは、キーが43の画像クラスであり、値がカウントであるディクショナリを出力します。 たとえば、下の出力では、エンドポイントが画像クラス:9を2つと、画像クラス:14を異常な頻度で予測したことがわかります。このクラスは322回予測されており、しきい値の50%よりも高くなっています。\n", "```\n", "Warning: Class 14 ('Stop sign') predicted more than 80 % of the time which is above the threshold\n", "Predicted classes {9: 2, 19: 1, 25: 1, 14: 322, 13: 5, 5: 1, 8: 10, 18: 1, 31: 4, 26: 8, 33: 4, 36: 4, 29: 20, 12: 8, 22: 4, 6: 4}\n", "\n", "```\n", "\n", "ディクショナリの値もCloudWatchメトリクスとして保存されるため、CloudWatchコンソールを使用してメトリクスデータのグラフを作成できます。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### エンドポイントを更新してSageMaker Debugger hookを有効にする\n", "\n", "Processingジョブが問題を検出したら、モデルについてさらに洞察を得る時が来ました。 エンドポイントの推論関数を変更して、モデルからテンソルを出力します(entry_pointの引数のスクリプトが変更されています)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sagemaker_model = PyTorchModel(\n", " model_data=f's3://{sagemaker_session.default_bucket()}/model/model.tar.gz',\n", " role=role,\n", " source_dir='code',\n", " entry_point='pretrained_model_with_debugger_hook.py',\n", " framework_version='1.3.1',\n", " py_version='py3',\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "変更した推論関数を見てみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pycat code/pretrained_model_with_debugger_hook.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "model_fnに[SageMaker Debugger hook configuration](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/api.md#hook)を作成します。これはinclude_regexに出力したいテンソルの名前を示す正規表現を取ります。テンソルは、SageMakerのデフォルトバケットの「endpoint/tensors」に保存されます。\n", "\n", "次に、既存のエンドポイントを更新します。これにより、古いentry_pointファイルが新しいファイルに置き換えられます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Model.deploy()のupdate_endpointパラメーターがSageMaker SDKv2で削除されたため、\n", "# ここではいくつかの内部/プライベート関数を使用して、APIバックエンドにモデルを強制的に登録します...\n", "\n", "sagemaker_model._init_sagemaker_session_if_does_not_exist('ml.m5.xlarge')\n", "sagemaker_model._create_sagemaker_model('ml.m5.xlarge')\n", "\n", "# 次に、エンドポイントを新しく登録されたモデルバージョンにポイントします。\n", "predictor.update_endpoint(\n", " model_name=sagemaker_model.name,\n", " initial_instance_count=1,\n", " instance_type='ml.m5.xlarge',\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これで推論リクエストが行われるたびに、テンソルが記録され、Amazon S3にアップロードされます。 それでは、さらに敵対的な画像を送信しましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "ncols = 4\n", "\n", "for index, file in enumerate(glob.glob('adversarial_examples/*png')):\n", " # Load image file to array:\n", " image = Image.open(file)\n", "\n", " # Invoke the endpoint: (Returns a 2D 1x1 array)\n", " result = predictor.predict(image)\n", "\n", " # Plot the results:\n", " ixcol = index % ncols\n", " if (ixcol == 0):\n", " plt.show()\n", " fig = plt.figure(figsize=(ncols*4.5, 4.5))\n", " plt.subplot(1, ncols, ixcol + 1)\n", " plt.title(signnames[result[0][0]])\n", " plt.imshow(image)\n", " plt.axis('off')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これで顕著性(saliency)を計算して、モデルから視覚的な説明を取得できます。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SageMaker Debuggerを使用して予測結果を分析する\n", "\n", "関連するテンソルをキャプチャするようにSageMaker Debugger hookを構成しました。 エンドポイントが更新され、テンソルがアップロードされたので、[trial](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/analysis.md#creating-a-trial-object)を作成し、Amazon S3からデータを読み取ります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from smdebug.trials import create_trial\n", "\n", "trial = create_trial(f's3://{sagemaker_session.default_bucket()}/endpoint/tensors')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "これらのテンソルから、入力画像のどの領域が予測結果にとって最も重要であるかを説明する顕著性(saliency)マップを計算できます。 [Full-Gradient Representation for Neural Network Visualization [1]](https://arxiv.org/abs/1905.00780)で説明されている方法では、すべての中間特徴とそのバイアスが必要です。 次のセルは、バッチノルム層とダウンサンプリング層の出力の勾配と対応するバイアスを取得します。 ResNet以外のモデルを使用する場合は、次のセルの正規表現を調整する必要がある場合があります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "biases, gradients = [], []\n", "\n", "for tname in trial.tensor_names(regex='.*gradient.*bn.*output|.*gradient.*downsample.1.*output'):\n", " gradients.append(tname)\n", " \n", "for tname in trial.tensor_names(regex='^(?=.*bias)(?:(?!fc).)*$'):\n", " biases.append(tname)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BatchNormレイヤの場合、暗黙的なバイアス(implicit biases)を計算する必要があります。 次のコードセルで、必要なテンソルを取得します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bn_weights, running_vars, running_means = [], [], []\n", "\n", "for tname in trial.tensor_names(regex='.*running_mean'):\n", " running_means.append(tname)\n", " \n", "for tname in trial.tensor_names(regex='.*running_var'):\n", " running_vars.append(tname)\n", "\n", "for tname in trial.tensor_names(regex='.*bn.*weight|.*downsample.1.*weight'):\n", " bn_weights.append(tname) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "結果を正規化するヘルパー関数を定義します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def normalize(tensor):\n", " tensor = tensor - np.min(tensor)\n", " tensor = tensor / np.max(tensor)\n", " return tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BatchNormレイヤの移動平均は、全体的なバイアスを計算するときに考慮する必要がある暗黙的のバイアス(implicit bias)を導入します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def compute_implicit_biases(bn_weights, running_vars, running_means, step):\n", " implicit_biases = []\n", " for weight_name, running_var_name, running_mean_name in zip(bn_weights, running_vars, running_means):\n", " weight = trial.tensor(weight_name).value(step_num=step, mode=modes.PREDICT)\n", " running_var = trial.tensor(running_var_name).value(step_num=step, mode=modes.PREDICT)\n", " running_mean = trial.tensor(running_mean_name).value(step_num=step, mode=modes.PREDICT)\n", " implicit_biases.append(- running_mean / np.sqrt(running_var) * weight)\n", " return implicit_biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "次のコードセルは、すべてのテンソルをフェッチし、画像ごとの顕著性(saliency)マップを計算します。 赤のピクセルは最も関連性の高いピクセルを示し、青のピクセルは画像クラスを予測するための最も関連性の低いピクセルを示します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "from smdebug import modes\n", "from smdebug.core.modes import ModeKeys\n", "import cv2\n", "import scipy.ndimage\n", "import scipy.special\n", "import utils\n", "\n", "for step in trial.steps():\n", "\n", " image_batch = trial.tensor(\"ResNet_input_0\").value(step_num=step, mode=modes.PREDICT)\n", "\n", " #compute implicit biases from batchnorm layers\n", " implicit_biases = compute_implicit_biases(bn_weights, running_vars, running_means, step)\n", "\n", " for item in range(image_batch.shape[0]):\n", "\n", " #input image\n", " image = image_batch[item,:,:,:]\n", "\n", " #get gradients of input image\n", " image_gradient = trial.tensor(\"gradient/image\").value(step_num=step, mode=modes.PREDICT)[item,:] \n", " image_gradient = np.sum(normalize(np.abs(image_gradient * image)), axis=0)\n", " saliency_map = image_gradient\n", "\n", " for gradient_name, bias_name, implicit_bias in zip(gradients, biases, implicit_biases):\n", "\n", " #get gradients and bias vectors for corresponding step\n", " gradient = trial.tensor(gradient_name).value(step_num=step, mode=modes.PREDICT)[item:item+1,:,:,:]\n", " bias = trial.tensor(bias_name).value(step_num=step, mode=modes.PREDICT) \n", " bias = bias + implicit_bias\n", "\n", " #compute full gradient\n", " bias = bias.reshape((1,bias.shape[0],1,1))\n", " bias = np.broadcast_to(bias, gradient.shape)\n", " bias_gradient = normalize(np.abs(bias * gradient))\n", "\n", " #interpolate to original image size\n", " for channel in range(bias_gradient.shape[1]):\n", " interpolated = scipy.ndimage.zoom(bias_gradient[0,channel,:,:], 128/bias_gradient.shape[2], order=1)\n", " saliency_map += interpolated \n", "\n", "\n", " #normalize\n", " saliency_map = normalize(saliency_map) \n", "\n", " #predicted class and propability\n", " predicted_class = trial.tensor(\"fc_output_0\").value(step_num=step, mode=modes.PREDICT)[item,:]\n", " print(\"Predicted class:\", np.argmax(predicted_class))\n", " scores = np.exp(np.asarray(predicted_class))\n", " scores = scores / scores.sum(0)\n", "\n", " #plot image and heatmap\n", " utils.plot_saliency_map(saliency_map, image, np.argmax(predicted_class), str(int(np.max(scores) * 100)), signnames )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cleanup\n", "\n", "以下のセルで削除できない場合はコンソールから削除してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "monitor.delete_monitoring_schedule()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "time.sleep(60) # actually wait for the deletion" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }