{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ML Enablement Workshop: サービスの解約率改善シナリオ クラウド活用編\n", "\n", "## コンテンツ\n", "\n", "1. 背景\n", "1. 環境構築\n", "1. 学習をスケールする\n", "1. モデルをホスティングする\n", " 1. 性能評価\n", " 1. エンドポイントの削除\n", "1. Notebookを移行する\n", "\n", "---\n", "\n", "## 1.背景\n", "\n", "サービスの解約率を改善するために、 Studio Lab では機能的・コンピューティングリソース的に不十分な状況に直面することがあるかもしれません。例えば、重要なデータは Studio Lab に持ち出せないかもしれませんし、モデルを学習する、本番同等のトランザクションで検証するのに Studio Lab では力不足かもしれません。本Notebookでは、 Studio Lab では不十分な状況に直面した時に Amazon SageMaker を使用し機械学習の価値検証を継続する方法を解説します。 Studio Lab には AWS の機能を呼び出す AWS SDK がインストール済みで、 SageMaker への Notebook の移行を行う方法も整備されています。\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## 2.環境構築\n", "\n", "### 2.1 Studio Lab の環境構築\n", "\n", "本 Notebook を動かすための環境構築手順は本体のシナリオと同じため、先に [customer_churn.ipynb](./customer_churn.ipynb) を参照し環境構築を行ってください。\n", "Jupyter Notebookの右上にある虫の隣のボタンをクリックしカーネルを切り替えます。\n", "\n", "\n", "\n", "### 2-2. AWS へ接続するための環境構築\n", "今回のサンプルでは、Studio Lab で動かしている Notebook 上ではなく AWS 環境上でモデルを学習およびデプロイして使い方を確かめてみます。そのためには、Notebook から AWS 環境にアクセスする必要があります。その認証情報をこのステップでは設定します。\n", "\n", "IAM ユーザーを作成し、そこから得られるアクセスキーとシークレットキーを登録します\n", "AWSへアクセスするためのユーザー (IAM ユーザー) を作成します。IAM ユーザーの作成方法は以下のページを参考にします。名前は任意ですが、以降では`sagemaker-studio-lab-access`として扱います。\n", "\n", "- https://docs.aws.amazon.com/ja_jp/IAM/latest/UserGuide/id_users_create.html#id_users_create_console\n", "\n", "まずは、AWS のコンソール画面を開いて左上の検索窓で「IAM」と検索します。トップに出てくる IAM をクリックして IAM のサービスページを開きます。\n", "\n", "\n", "\n", "左側メニューから、「ユーザー」をクリックして IAM ユーザーの設定画面に遷移します。\n", "\n", "\n", "\n", "次に、「ユーザーを追加」をクリックしてユーザーの作成を開始します。\n", "\n", "\n", "\n", "ユーザー名に「sagemaker-studio-lab-access (画像では whisper-sample-user)」(他の名称でも大丈夫です)、「アクセスキー - プログラムによるアクセス」にチェックをつけます。\n", "\n", "\n", "\n", "「次のステップ」をクリックします。 \n", "その後「既存のポリシーを直接アタッチ」を選択し、ポリシーの検索で「SageMakerFullAccess」と入力します。そうすると「AWSSageMakerFullAccess」のポリシー候補が現れるのでこれを選択します。\n", "\n", "\n", "\n", "次に、検索窓に「PowerUserAccess」と検索し候補に出てきた「PowerUserAccess」を選択します。 \n", "\n", "\n", "\n", "「次のステップ」をクリックするとタグの設定画面が出てきますが、ここは特に入力せずにスキップします。 \n", "\n", "これまでに設定した項目の確認ページが出てくるので問題なければ「ユーザーの作成」をクリックします。\n", "\n", " \n", "\n", "無事ユーザーが作成されるとユーザーキーとシークレットキーが表示されるのでメモに残しておきます。これらの情報を使って Studio Lab 経由で AWS 環境にアクセスを行います。 \n", "**ここで取得されるクレデンシャル情報の扱いには十分注意してください**。\n", "\n", "\n", "\n", "\n", "次に、 Studio Lab の画面に戻って先ほど取得したアクセスキーなどの情報を登録していきます。 \n", "\n", "画面上部のメニューから 「File -> New -> Terminal」 と選択してターミナルの起動をします。 \n", "\n", "\n", "開かれたターミナルで `aws configure` を実行します。\n", "そこでアクセスキーとシークレットキーを聞かれるので先ほどメモした値を入力します。 \n", "\n", "\n", "\n", "以上で、認証情報の設定は完了です。ではこれから実際にモデルを動かしていきましょう。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2-3. SageMaker Training Instance が利用する IAM ロールを作成する\n", "\n", "学習を実行するインスタンスの権限となる、 IAM ロールを作成します。 IAM ロールの作成方法は以下のページを参考にします。名前は任意ですが、以降では`SageMakerStudioLabExecuteRole`として扱います。 \n", "\n", "- https://docs.aws.amazon.com/ja_jp/glue/latest/dg/create-an-iam-role-sagemaker-notebook.html\n", "\n", "AWS のコンソール画面に戻ります。 \n", "先ほどと同様の手順で IAM のサービス画面を開き、「ロール」を左側のメニューから選択します。 \n", "IAM ロールの画面が開かれたら「ロールの作成」ボタンをクリックします。 \n", "\n", " \n", "\n", "ロールの作成画面が表示されたら信頼されるエンティティタプとして「AWS のサービス」を選択し、ユースケースのところは下の検索欄から「SageMaker」などと検索して SageMaker を選択します。\n", "\n", "\n", "\n", "「次へ」をクリックし、「AmazonSageMakerFullAccess」のポリシーがアタッチされていることを確認します。 \n", "\n", "\n", "\n", "「次へ」をクリックし、Role 名を設定します。「StudioLabExecuteRole(画像では StudioLabWhisperExecutionRole)」と入力し、他の項目はいじらずに「ロールを作成」をクリックします。 \n", "こちらの Role 名も自由に設定して問題ありません。 \n", "\n", "\n", "\n", "作成した IAM ロールのリソースネームである ARN を取得します。 \n", "IAM ロールの画面から、検索欄で「StudioLabExecuteRole(画像では StudioLabWhisperExecutionRole)」などと入力して先ほど作成した IAM ロールを探して選択します。 \n", "\n", " \n", "\n", "IAM ロールの詳細情報が表示されるので、ARN の隣にあるコピーボタンをクリックして ARN をコピーします。\n", " \n", "\n", "コピペした値を置き換えて role の値を設定します。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "role = \"arn:aws:iam::000000000000:role/SageMakerStudioLabExecuteRole\" # コピペした値で置き換える" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "では、はじめていきましょう、はじめに利用するライブラリを読み込んでおきます。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Define IAM role\n", "from pathlib import Path\n", "import boto3\n", "import sagemaker\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "data_root = Path(\"../../data/\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.学習をスケールする\n", "\n", "Studio Lab の GPU では性能、稼働時間が足りない場合 AWS で学習を行うことができます。学習を始める前に、学習データをAmazon S3にアップロードしSageMakerから利用できるようにします。\n", "\n", "※事前に `customer_churn.ipynb` のシナリオを実行しデータを作成しておく必要があります。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "sagemaker_session = sagemaker.Session()\n", "input_train = sagemaker_session.upload_data(path=str(data_root.joinpath('interim/churn_train.csv')), key_prefix='sagemaker/DEMO-xgboost-churn')\n", "input_validation = sagemaker_session.upload_data(path=str(data_root.joinpath('interim/churn_validation.csv')), key_prefix='sagemaker/DEMO-xgboost-churn')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`input_train` と `input_validation` にはアップロードしたファイルのS3パスが保存されています。これらは csv ファイルで、学習させるには以下のようなデータである必要がありますが、先の前処理の段階でこのようなデータ形式に変換しているため、追加の処理は必要ありません。\n", "\n", "- 1列目が予測対象のデータ\n", "- ヘッダ行はなし\n", "\n", "学習に使ったモデルは XGBoost でしたので、 Amazon SageMaker が用意している XGBoost のコンテナを利用して学習します。このコンテナは、ファイルをデフォルトで libsvm 形式と認識するため、`TrainingInput`という関数を利用して、`content_type='text/csv'`を明示的に指定します。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# from sagemaker.session import s3_input\n", "from sagemaker.inputs import TrainingInput\n", "\n", "content_type='text/csv'\n", "s3_input_train = TrainingInput(input_train, content_type=content_type)\n", "s3_input_validation = TrainingInput(input_validation, content_type=content_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Amazon SageMaker は、マネージドで、分散学習が設定済みで、リアルタイム推論のためのホスティングも可能な XGBoost コンテナを用意しています。 リージョンごと、アルゴリズムごとに用意されているコンテナの URI は [Docker レジストリパスとサンプルコード](https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html)で確認できます。XGBoost のコンテナの場所を取得しましょう。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-xgboost:1.2-1'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "container = sagemaker.image_uris.retrieve(\"xgboost\", boto3.Session().region_name, \"1.2-1\")\n", "container" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "それでは学習を始めましょう。学習のためにハイパーパラメータを指定したり、学習のインスタンスの数やタイプを指定することができます。XGBoost における主要なハイパーパラメータは以下のとおりです。\n", "\n", "- `max_depth` アルゴリズムが構築する木の深さをコントロールします。深い木はより学習データに適合しますが、計算も多く必要で、overfiting になる可能性があります。たくさんの浅い木を利用するか、少数の深い木を利用するか、モデルの性能という面ではトレードオフがあります。\n", "- `subsample` 学習データのサンプリングをコントロールします。これは overfitting のリスクを減らしますが、小さすぎるとモデルのデータが不足してしまいます。\n", "- `num_round` ブースティングを行う回数をコントロールします。以前のイテレーションで学習したときの残差を、以降のモデルにどこまで利用するかどうかを決定します。多くの回数を指定すると学習データに適合しますが、計算も多く必要で、overfiting になる可能性があります。\n", "- `eta` 各ブースティングの影響の大きさを表します。大きい値は保守的なブースティングを行います。\n", "- `gamma` ツリーの成長の度合いをコントロールします。大きい値はより保守的なモデルを生成します。\n", "\n", "XGBoostのhyperparameterに関する詳細は [GitHub](https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst) もチェックしてください。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-09-15 12:22:20 Starting - Starting the training job...ProfilerReport-1663244539: InProgress\n", "...\n", "2022-09-15 12:23:13 Starting - Preparing the instances for training.........\n", "2022-09-15 12:24:45 Downloading - Downloading input data...\n", "2022-09-15 12:25:12 Training - Downloading the training image.........\n", "2022-09-15 12:26:48 Uploading - Uploading generated training model[2022-09-15 12:26:36.718 ip-10-0-172-176.ap-northeast-1.compute.internal:1 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None\n", "INFO:sagemaker-containers:Imported framework sagemaker_xgboost_container.training\n", "INFO:sagemaker-containers:Failed to parse hyperparameter objective value binary:logistic to Json.\n", "Returning the value itself\n", "INFO:sagemaker-containers:No GPUs detected (normal if no gpus installed)\n", "INFO:sagemaker_xgboost_container.training:Running XGBoost Sagemaker in algorithm mode\n", "INFO:root:Determined delimiter of CSV input is ','\n", "INFO:root:Determined delimiter of CSV input is ','\n", "INFO:root:Determined delimiter of CSV input is ','\n", "INFO:root:Determined delimiter of CSV input is ','\n", "INFO:root:Single node training.\n", "INFO:root:Train matrix has 3500 rows and 99 columns\n", "INFO:root:Validation matrix has 1000 rows\n", "[0]#011train-error:0.11743#011validation-error:0.12700\n", "[1]#011train-error:0.10429#011validation-error:0.10800\n", "[2]#011train-error:0.09714#011validation-error:0.10700\n", "[3]#011train-error:0.08600#011validation-error:0.10300\n", "[4]#011train-error:0.08457#011validation-error:0.09700\n", "[5]#011train-error:0.08143#011validation-error:0.09200\n", "[6]#011train-error:0.07714#011validation-error:0.08700\n", "[7]#011train-error:0.07343#011validation-error:0.08300\n", "[8]#011train-error:0.07029#011validation-error:0.07700\n", "[9]#011train-error:0.06914#011validation-error:0.07800\n", "[10]#011train-error:0.06657#011validation-error:0.07900\n", "[11]#011train-error:0.06543#011validation-error:0.07300\n", "[12]#011train-error:0.06343#011validation-error:0.07500\n", "[13]#011train-error:0.06286#011validation-error:0.07000\n", "[14]#011train-error:0.06286#011validation-error:0.07200\n", "[15]#011train-error:0.06400#011validation-error:0.07100\n", "[16]#011train-error:0.06286#011validation-error:0.07200\n", "[17]#011train-error:0.06200#011validation-error:0.06900\n", "[18]#011train-error:0.06000#011validation-error:0.06900\n", "[19]#011train-error:0.06000#011validation-error:0.06700\n", "[20]#011train-error:0.05971#011validation-error:0.06300\n", "[21]#011train-error:0.05914#011validation-error:0.06600\n", "[22]#011train-error:0.05914#011validation-error:0.06700\n", "[23]#011train-error:0.05857#011validation-error:0.07000\n", "[24]#011train-error:0.05800#011validation-error:0.06900\n", "[25]#011train-error:0.05800#011validation-error:0.06900\n", "[26]#011train-error:0.05629#011validation-error:0.06600\n", "[27]#011train-error:0.05571#011validation-error:0.06500\n", "[28]#011train-error:0.05514#011validation-error:0.06700\n", "[29]#011train-error:0.05486#011validation-error:0.06700\n", "[30]#011train-error:0.05571#011validation-error:0.06700\n", "[31]#011train-error:0.05371#011validation-error:0.06700\n", "[32]#011train-error:0.05143#011validation-error:0.06500\n", "[33]#011train-error:0.05200#011validation-error:0.06400\n", "[34]#011train-error:0.05229#011validation-error:0.06200\n", "[35]#011train-error:0.05114#011validation-error:0.06100\n", "[36]#011train-error:0.05114#011validation-error:0.06100\n", "[37]#011train-error:0.05000#011validation-error:0.05800\n", "[38]#011train-error:0.05000#011validation-error:0.06000\n", "[39]#011train-error:0.04886#011validation-error:0.05900\n", "[40]#011train-error:0.04771#011validation-error:0.05800\n", "[41]#011train-error:0.04657#011validation-error:0.06000\n", "[42]#011train-error:0.04629#011validation-error:0.06000\n", "[43]#011train-error:0.04600#011validation-error:0.06100\n", "[44]#011train-error:0.04429#011validation-error:0.06200\n", "[45]#011train-error:0.04400#011validation-error:0.06100\n", "[46]#011train-error:0.04400#011validation-error:0.06100\n", "[47]#011train-error:0.04200#011validation-error:0.06100\n", "[48]#011train-error:0.04229#011validation-error:0.06000\n", "[49]#011train-error:0.04229#011validation-error:0.06100\n", "[50]#011train-error:0.04114#011validation-error:0.06300\n", "[51]#011train-error:0.04086#011validation-error:0.06300\n", "[52]#011train-error:0.04086#011validation-error:0.06300\n", "[53]#011train-error:0.04200#011validation-error:0.06200\n", "[54]#011train-error:0.04229#011validation-error:0.06200\n", "[55]#011train-error:0.04086#011validation-error:0.06400\n", "[56]#011train-error:0.04086#011validation-error:0.06400\n", "[57]#011train-error:0.04086#011validation-error:0.06400\n", "[58]#011train-error:0.04057#011validation-error:0.06600\n", "[59]#011train-error:0.04086#011validation-error:0.06600\n", "[60]#011train-error:0.04086#011validation-error:0.06600\n", "[61]#011train-error:0.04086#011validation-error:0.06600\n", "[62]#011train-error:0.04057#011validation-error:0.06700\n", "[63]#011train-error:0.04057#011validation-error:0.06900\n", "[64]#011train-error:0.03914#011validation-error:0.06900\n", "[65]#011train-error:0.03886#011validation-error:0.06900\n", "[66]#011train-error:0.03857#011validation-error:0.06800\n", "[67]#011train-error:0.03857#011validation-error:0.06600\n", "[68]#011train-error:0.03857#011validation-error:0.06500\n", "[69]#011train-error:0.03600#011validation-error:0.06500\n", "[70]#011train-error:0.03629#011validation-error:0.06300\n", "[71]#011train-error:0.03629#011validation-error:0.06300\n", "[72]#011train-error:0.03600#011validation-error:0.06300\n", "[73]#011train-error:0.03571#011validation-error:0.06300\n", "[74]#011train-error:0.03571#011validation-error:0.06200\n", "[75]#011train-error:0.03571#011validation-error:0.06200\n", "[76]#011train-error:0.03543#011validation-error:0.06200\n", "[77]#011train-error:0.03686#011validation-error:0.06300\n", "[78]#011train-error:0.03686#011validation-error:0.06400\n", "[79]#011train-error:0.03686#011validation-error:0.06300\n", "[80]#011train-error:0.03714#011validation-error:0.06400\n", "[81]#011train-error:0.03657#011validation-error:0.06500\n", "[82]#011train-error:0.03657#011validation-error:0.06500\n", "[83]#011train-error:0.03629#011validation-error:0.06400\n", "[84]#011train-error:0.03629#011validation-error:0.06400\n", "[85]#011train-error:0.03543#011validation-error:0.06300\n", "[86]#011train-error:0.03543#011validation-error:0.06300\n", "[87]#011train-error:0.03400#011validation-error:0.06300\n", "[88]#011train-error:0.03457#011validation-error:0.06200\n", "[89]#011train-error:0.03429#011validation-error:0.06100\n", "[90]#011train-error:0.03429#011validation-error:0.06200\n", "[91]#011train-error:0.03457#011validation-error:0.06200\n", "[92]#011train-error:0.03429#011validation-error:0.06100\n", "[93]#011train-error:0.03457#011validation-error:0.06400\n", "[94]#011train-error:0.03457#011validation-error:0.06300\n", "[95]#011train-error:0.03457#011validation-error:0.06400\n", "[96]#011train-error:0.03457#011validation-error:0.06300\n", "[97]#011train-error:0.03457#011validation-error:0.06400\n", "[98]#011train-error:0.03486#011validation-error:0.06300\n", "[99]#011train-error:0.03457#011validation-error:0.06300\n", "\n", "2022-09-15 12:27:26 Completed - Training job completed\n", "ProfilerReport-1663244539: NoIssuesFound\n", "Training seconds: 148\n", "Billable seconds: 148\n" ] } ], "source": [ "sess = sagemaker.Session()\n", "\n", "hyperparameters = {\"max_depth\":\"5\",\n", " \"eta\":\"0.2\",\n", " \"gamma\":\"4\",\n", " \"min_child_weight\":\"6\",\n", " \"subsample\":\"0.8\",\n", " \"objective\":\"binary:logistic\",\n", " \"num_round\":\"100\"}\n", "\n", "xgb = sagemaker.estimator.Estimator(container,\n", " role, \n", " hyperparameters=hyperparameters,\n", " instance_count=1, \n", " instance_type='ml.m4.xlarge',\n", " sagemaker_session=sess)\n", "\n", "xgb.fit({'train': s3_input_train, 'validation': s3_input_validation}) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "学習ジョブは AWS Console からも確認できます。\n", "\n", "\n", "\n", "学習したモデルは S3 に格納されています。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'s3://sagemaker-ap-northeast-1-585936743357/sagemaker-xgboost-2022-09-15-12-22-19-369/output/model.tar.gz'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgb.model_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 4.モデルをホスティングする\n", "\n", "A/B テストを行う場合などは、別々に学習したモデルを API サーバーとして立てる必要があるかもしれません。 SageMaker では、学習が終われば`deploy()`を実行することで、エンドポイントを作成してモデルをデプロイできます。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------!" ] } ], "source": [ "xgb_predictor = xgb.deploy(initial_instance_count=1, instance_type = 'ml.m4.xlarge')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "デプロイしたモデルは AWS Console から確認できます。\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4-1.性能評価\n", "\n", "ホスティングしたモデルを使用し、簡単に予測を行うことができます。予測は http の POST の request を送るだけです。\n", "endpoint は `numpy` の `array` を受け取ることができないため、[`CSVSerializer`](https://sagemaker.readthedocs.io/en/stable/api/inference/serializers.html#sagemaker.serializers.CSVSerializer) を設定して `numpy` の `array` を csv 形式に変換して送ります。 逆に、endpoint から取得する時は csv からリストに変換します。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "xgb_predictor.serializer = sagemaker.serializers.CSVSerializer()\n", "xgb_predictor.deserializer = sagemaker.deserializers.CSVDeserializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "先のノートブックで作成済みのテストデータを受け取ると、これをデフォルト500行ずつのデータにわけて、エンドポイントに送信する `predict` という関数を用意します。あとは `predict` を実行して予測結果を受け取ります。 " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def predict(data, rows=500):\n", " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", " predictions = []\n", " for array in split_array:\n", " predictions.append(np.array(xgb_predictor.predict(array), dtype=np.float32))\n", "\n", " return np.concatenate(predictions, axis=1)\n", "\n", "test_data = pd.read_csv(data_root.joinpath('interim/churn_test.csv'), header=None)\n", "predictions = predict(test_data.values[:, 1:]) # 0列目はラベルのため除外" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "機械学習の性能を比較評価する方法はいくつかありますが、単純に、予測値と実際の値を比較しましょう。今回は、顧客が離反する `1` と離反しない `0` を予測しますので、この混同行列を作成します。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
predictions | \n", "0.0 | \n", "1.0 | \n", "
---|---|---|
actual | \n", "\n", " | \n", " |
0 | \n", "235 | \n", "18 | \n", "
1 | \n", "11 | \n", "236 | \n", "