{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Model Quality Monitoring: Step B" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Tips for running this notebook:\n", "- This notebook loads a large volume of raw data, so run it on an instance with at least 8 GB of memory.\n", "- It has been tested with the Python 3 (Data Science) kernel.\n", "- By default it uses the SageMaker default bucket. You can change this if needed.\n", "- Cell outputs are kept so that you can review the results without running anything. To start from a clean state, select \"Clear All Outputs\" from the right-click menu before you begin.\n", "- The schedules you create can also be viewed from the Endpoint menu under `SageMaker resource` (at the bottom of the left pane) in SageMaker Studio." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Variables shared across multiple notebooks" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Specify the endpoint name\n", "endpoint_name = 'nyctaxi-xgboost-endpoint'\n", "\n", "# Specify the endpoint config name\n", "endpoint_config_name = f'{endpoint_name}-config'\n", "\n", "# Specify the name of the model quality monitoring schedule\n", "model_quality_monitoring_schedule = f'{endpoint_name}-model-quality-schedule'\n", "\n", "# Use the SageMaker default bucket as the Model Monitor bucket\n", "# If you use a different bucket, specify it here\n", "import sagemaker\n", "bucket = sagemaker.Session().default_bucket()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Set the S3 prefixes for the baseline and the reports, where the monitoring results are stored." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## B1 (Option A): Run inference and upload the Ground Truth to S3\n", "After running inference, you need to wait for the next scheduled monitoring job." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Set the output prefix for the baseline\n", "baseline_prefix = 'model_monitor/model_quality_baseline'\n", "\n", "# Set a prefix shared by multiple reports, for time-series visualization\n", "report_prefix = 'model_monitor/model_quality_monitoring_report'\n", "\n", "# Specify the prefix to upload the Ground Truth to\n", "ground_truth_prefix = 'model_monitor/model_quality_ground_truth'" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Run inference with an Inference ID" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": 
{}, "outputs": [], "source": [ "# Specify the date to run inference for\n", "prediction_target_date = '2021-09-15'\n", "\n", "# Specify the data sampling rate (match the setting used when the model was created)\n", "sampling_rate = 20\n", "\n", "# Specify the directory where inference results are saved\n", "result_dir = 'prediction_results_model_quality'" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import os\n", "import boto3\n", "import pandas as pd\n", "import time\n", "from datetime import datetime\n", "import model_utils" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def get_data_for_pred(target, sampling_rate):\n", "    # Load the previous and current month so that history features can be computed\n", "    previous_year, previous_month = model_utils.get_previous_year_month(target.year, target.month)\n", "    df_previous_month = model_utils.get_raw_data(previous_year, previous_month, sampling_rate)\n", "    df_current_month = model_utils.get_raw_data(target.year, target.month, sampling_rate)\n", "    df_data = pd.concat([df_previous_month, df_current_month])\n", "    del df_previous_month\n", "    del df_current_month\n", "\n", "    # Extract features, then keep only the rows for the target month\n", "    df_features = model_utils.extract_features(df_data)\n", "    df_features = model_utils.filter_current_month(df_features, target.year, target.month)\n", "\n", "    return df_features" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data for 2021-09\n", "Predicting 2021-09-15 00:00:00 nyctaxi-xgboost-endpoint\n", "Prediction completed. Result file: prediction_results_model_quality/prediction-result-2021-09-15.csv\n" ] } ], "source": [ "# Create the result directory if it does not exist\n", "os.makedirs(result_dir, exist_ok=True)\n", "\n", "target_date = pd.to_datetime(prediction_target_date)\n", "print('Loading data for', target_date.strftime('%Y-%m'))\n", "df_features = get_data_for_pred(target_date, sampling_rate)\n", "\n", "# Run prediction for the target date\n", "print('Predicting', target_date, endpoint_name)\n", "df_pred = df_features[df_features.index == target_date].copy()\n", "df_pred[['pred', 'inference_id']] = model_utils.exec_prediction(endpoint_name, df_pred)\n", "\n", "# Save the prediction result\n", "result_file = f'{result_dir}/prediction-result-{prediction_target_date}.csv'\n", "df_pred.to_csv(result_file, index=False)\n", "print('Prediction completed. Result file:', result_file)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Merge the Inference IDs obtained at inference time with the Ground Truth and upload to S3" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import io\n", "import json\n", "from datetime import datetime\n", "\n", "import boto3\n", "import pandas as pd\n", "import sagemaker\n", "from sagemaker.s3 import S3Uploader\n", "\n", "s3r = boto3.resource('s3')" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Column names for the Ground Truth (here, pickup_count) and the inference ID (here, inference_id)\n", "# This assumes one day of inference results is saved in a local CSV file, with the Ground Truth and the inference results in the same file.\n", "ground_truth_colname = 'pickup_count'\n", "inference_id_colname = 'inference_id'\n", "\n", "# Set the prefix to upload the Ground Truth to\n", "# This must match the monitoring job settings used when create_monitoring_schedule was run\n", "bucket = sagemaker.Session().default_bucket()\n", "ground_truth_path = f's3://{bucket}/{ground_truth_prefix}'" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, 
"outputs": [ { "data": { "text/html": [ "
\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>pickup_count</th><th>history_12slots</th><th>history_16slots</th><th>history_20slots</th><th>history_24slots</th><th>history_28slots</th><th>history_32slots</th><th>history_36slots</th><th>history_40slots</th><th>history_44slots</th><th>...</th><th>tolls_amount_mean_20slot</th><th>tolls_amount_mean_96slot</th><th>tolls_amount_mean_100slot</th><th>tolls_amount_mean_104slot</th><th>tolls_amount_mean_192slot</th><th>tolls_amount_mean_196slot</th><th>tolls_amount_mean_200slot</th><th>time_slot</th><th>pred</th><th>inference_id</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>113</td><td>264.0</td><td>324.0</td><td>350.0</td><td>359.0</td><td>369.0</td><td>362.0</td><td>352.0</td><td>301.0</td><td>294.0</td><td>...</td><td>0.352771</td><td>0.440594</td><td>0.674265</td><td>0.203778</td><td>0.169022</td><td>0.544014</td><td>0.699157</td><td>0</td><td>88</td><td>f7642aa4-4496-43be-9185-89e9cfe14c53</td></tr>\n",
"<tr><th>1</th><td>91</td><td>249.0</td><td>274.0</td><td>330.0</td><td>391.0</td><td>351.0</td><td>320.0</td><td>343.0</td><td>340.0</td><td>279.0</td><td>...</td><td>0.277879</td><td>0.786000</td><td>0.751639</td><td>0.463443</td><td>0.883146</td><td>0.681200</td><td>0.872543</td><td>1</td><td>71</td><td>a0911496-6e3e-4984-84e2-349152581687</td></tr>\n",
"<tr><th>2</th><td>81</td><td>273.0</td><td>270.0</td><td>327.0</td><td>405.0</td><td>376.0</td><td>363.0</td><td>333.0</td><td>316.0</td><td>298.0</td><td>...</td><td>0.260398</td><td>0.777119</td><td>0.564286</td><td>0.536010</td><td>1.105844</td><td>0.661009</td><td>0.595752</td><td>2</td><td>72</td><td>bed71a18-4006-461d-863d-e9c1ec7e65d6</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>3 rows × 145 columns</p>
\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>mae</th><th>mse</th><th>rmse</th><th>r2</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>2020-01-06 01:00:00</th><td>48.736082</td><td>3559.985828</td><td>59.665617</td><td>0.915222</td></tr>\n",
"<tr><th>2020-01-13 01:00:00</th><td>33.640843</td><td>2203.995053</td><td>46.946726</td><td>0.959612</td></tr>\n",
"<tr><th>2020-01-20 01:00:00</th><td>56.177981</td><td>5339.805945</td><td>73.073976</td><td>0.840637</td></tr>\n",
"<tr><th>2020-01-27 01:00:00</th><td>32.058851</td><td>1598.723379</td><td>39.984039</td><td>0.966500</td></tr>\n",
"<tr><th>2020-02-03 01:00:00</th><td>39.323701</td><td>2523.980451</td><td>50.239232</td><td>0.948893</td></tr>\n",
"<tr><th>2020-02-10 01:00:00</th><td>41.485255</td><td>2958.223253</td><td>54.389551</td><td>0.949442</td></tr>\n",
"<tr><th>2020-02-17 01:00:00</th><td>47.202256</td><td>3553.185067</td><td>59.608599</td><td>0.865158</td></tr>\n",
"<tr><th>2020-02-24 01:00:00</th><td>39.585234</td><td>2842.260041</td><td>53.312851</td><td>0.942540</td></tr>\n",
"<tr><th>2020-03-02 01:00:00</th><td>41.083948</td><td>2872.077381</td><td>53.591766</td><td>0.942076</td></tr>\n",
"<tr><th>2020-03-09 01:00:00</th><td>61.858424</td><td>7275.347270</td><td>85.295646</td><td>0.812754</td></tr>\n",
"<tr><th>2020-03-16 01:00:00</th><td>101.941843</td><td>14074.934422</td><td>118.637829</td><td>-1.424614</td></tr>\n",
"<tr><th>2020-03-23 01:00:00</th><td>220.142144</td><td>57067.685159</td><td>238.888437</td><td>-176.484370</td></tr>\n",
"<tr><th>2020-03-30 01:00:00</th><td>190.397829</td><td>45052.495289</td><td>212.255731</td><td>-274.076476</td></tr>\n",
"<tr><th>2020-04-06 01:00:00</th><td>92.610319</td><td>9805.263248</td><td>99.021529</td><td>-85.943666</td></tr>\n",
"<tr><th>2020-04-13 01:00:00</th><td>63.392568</td><td>4285.378399</td><td>65.462802</td><td>-37.234659</td></tr>\n",
"<tr><th>2020-04-20 01:00:00</th><td>62.310475</td><td>4175.124192</td><td>64.615201</td><td>-34.572366</td></tr>\n",
"<tr><th>2020-04-27 01:00:00</th><td>61.910881</td><td>4149.062384</td><td>64.413216</td><td>-25.887938</td></tr>\n",
"<tr><th>2020-05-04 01:00:00</th><td>59.307794</td><td>3856.049834</td><td>62.097100</td><td>-18.118762</td></tr>\n",
"<tr><th>2020-05-11 01:00:00</th><td>60.788436</td><td>4143.170986</td><td>64.367468</td><td>-18.026519</td></tr>\n",
"<tr><th>2020-05-18 01:00:00</th><td>59.132732</td><td>3848.765692</td><td>62.038421</td><td>-16.037470</td></tr>\n",
"</tbody>\n",
"</table>
\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>eventId</th><th>inferenceId</th><th>inferenceTime</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>dcbf4055-1196-4b71-ab54-e290b4b3530e</td><td>50f38bbc-1c15-4d81-ac2a-7a922d8adef6</td><td>2022-12-11T07:27:10Z</td></tr>\n",
"<tr><th>1</th><td>53b77003-449e-4ee9-9eea-88ac0b66df40</td><td>b5d9a7ac-9eaf-4fa4-8549-c4b1e2c78492</td><td>2022-12-11T07:27:10Z</td></tr>\n",
"<tr><th>2</th><td>07e28837-40c6-4731-a1fa-dcdfbf543ec6</td><td>20b5e6f6-0b84-44fc-8977-68da623391d1</td><td>2022-12-11T07:27:11Z</td></tr>\n",
"<tr><th>3</th><td>3004ec50-6b30-4043-b149-76e3322016cd</td><td>37b11b6d-1087-4378-948d-247a86da68f6</td><td>2022-12-11T07:27:11Z</td></tr>\n",
"<tr><th>4</th><td>9336b2ff-dea3-4095-9d9c-f882cab4e7b3</td><td>f6bed701-025d-4aac-9941-c90e9ede4509</td><td>2022-12-11T07:27:11Z</td></tr>\n",
"<tr><th>...</th><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><th>83</th><td>eb42e34b-93e5-46e5-9f28-68807f93348c</td><td>db5843e2-a983-4a19-b6fe-ae6c8a321ce9</td><td>2022-12-11T07:27:15Z</td></tr>\n",
"<tr><th>84</th><td>72846ebd-4f6b-4571-9f09-554d792f5609</td><td>190dbd03-4479-4706-bbe0-0b101e1a6fe2</td><td>2022-12-11T07:27:15Z</td></tr>\n",
"<tr><th>85</th><td>bddc4048-e81c-4bbf-be19-d2359f5d8a0b</td><td>bf1842fd-2884-4f2b-9932-8a2acb9b754c</td><td>2022-12-11T07:27:15Z</td></tr>\n",
"<tr><th>86</th><td>38c518b2-5d79-4cd6-94b6-6497bb9e992b</td><td>fb6c7636-cce2-4727-a79e-599ea8c1fd1d</td><td>2022-12-11T07:27:15Z</td></tr>\n",
"<tr><th>87</th><td>1d62673c-fb3a-4dad-94cf-50497d76723a</td><td>ad8d3c31-64ef-4b0a-92b0-e34e870d9fbf</td><td>2022-12-11T07:27:15Z</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>88 rows × 3 columns</p>
\n", "