{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# XGBoost による顧客離反分析 (Churn Analysis)\n", "\n", "---\n", "\n", "---\n", "\n", "## コンテンツ\n", "\n", "1. [背景](#1.背景)\n", "1. [セットアップ](#2.セットアップ)\n", "1. [データ](#3.データ)\n", "1. [学習](#4.学習)\n", "1. [ホスティング](#5.ホスティング)\n", " 1. [評価](#5-1.評価)\n", " 1. [推論エラーのコスト](#5-2.推論エラーのコスト)\n", " 1. [最適な閾値を探す](#5-3.最適な閾値を探す)\n", "1. [エンドポイントの削除](#6.エンドポイントの削除)\n", "---\n", "\n", "## 1.背景\n", "\n", "_このノートブックで実施する内容は、[AWS blog post](https://aws.amazon.com/blogs/ai/predicting-customer-churn-with-amazon-machine-learning/)にも記載されています。_\n", "\n", "どのようなビジネスであっても、顧客を失うことは大きな損害です。もし、満足していない顧客を早期に見つけることができれば、そのような顧客をキープするためのインセンティブを提供できる可能性があるでしょう。このノートブックでは、満足していない顧客を自動で認識するために機械学習 (Machine Learning, ML) を利用する方法を説明します。このような顧客の離反分析は Customer Churn Prediction と呼ばれています。機械学習モデルは完璧な予測を行えないので、このノートブックでは予測のエラーが生じたときの相対的なコストを考慮して、機械学習を利用したときの成果を金額で評価します。\n", "\n", "ここでは、私達にとってなじみのある離反分析、携帯電話会社からの離反を取り上げます。携帯電話会社が、ある顧客が離反しそうと察知したら、その顧客にタイムリーにインセンティブを与えます。つまり、電話をアップグレードしたり、新しい機能を使えるようになったりして、引き続き携帯電話会社を使おうと思うかもしれません。インセンティブは、顧客が離反して再度獲得するまでにかかるコストよりもずっと小さいことが多いです。\n", "\n", "\n", "\n", "---\n", "\n", "## 2.セットアップ\n", "\n", "まず、このノートブックインスタンスに付与されている IAM role を `get_execution_role()` から取得しましょう。後ほど、SageMaker の学習やホスティングを行いますが、そこで IAM role が必要になります。そこで、ノートブックインスタンスの IAM role を、学習やホスティングでも利用します。\n", "通常、role を取得するためにはAWS SDKを利用した数行のコードを書く必要があります。ここでは `get_execution_role()` のみで role を取得可能です。SageMaker Python SDK は、データサイエンティストが機械学習以外のコードを簡潔に済ませるために、このような関数を提供しています。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true }, "outputs": [], "source": [ "# bucket = ''\n", "# prefix = 'sagemaker/DEMO-xgboost-churn'\n", "\n", "# Define IAM role\n", "import boto3\n", "import re\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以降で利用するライブラリをここで読み込んでおきます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import io\n", "import os\n", "import sys\n", "import time\n", "import json\n", "from IPython.display import display\n", "from time import strftime, gmtime\n", "import sagemaker\n", "\n", "print('Current SageMaker Python SDK Version ={0}'.format(sagemaker.__version__))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上記セルを実行して、SageMaker Python SDK Version が 1.xx.x の場合、以下のセルのコメントアウトを解除してから実行してください。実行が完了したら、上にあるメニューから [Kernel] -> [Restart] を選択してカーネルを再起動してください。\n", "\n", "再起動が完了したら、このノートブックの一番上のセルから再度実行してください。その場合、以下のセルを実行する必要はありません。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !pip install -U --quiet \"sagemaker==2.16.1\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.データ\n", "\n", "携帯電話会社は、どの顧客が最終的に離反したか、または、サービスを使い続けたかの履歴データをもっています。この履歴データに対して学習を行うことで、携帯電話会社の顧客離反を予想するモデルを構築します。モデルの学習が終わった後、任意の顧客のデータ (モデルの学習で利用したものと同じ情報を利用します)をモデルに入力すると、モデルはその顧客が離反しそうかどうかを予測します。もちろん、モデルは誤って予測することも考えられるので、将来を予測することはやはり難しいですが、そのような誤りに対応する方法も紹介します。\n", "\n", "ここで利用するデータセットは一般的に利用可能で、書籍 [Discovering Knowledge in Data](https://www.amazon.com/dp/0470908742/) の中で Daniel T. Larose が言及しているものです。そのデータセットは、著者によって University of California Irvine Repository of Machine Learning Datasets に提供されています。ここでは、そのデーセットをダウンロードして読み込んでみます。\n", "\n", "Jupyter notebook では、冒頭に `!` を入力することで、シェルコマンドを実行することができます。AWS CLIを用いてS3からデータをダウンロードします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 cp s3://sagemaker-sample-files/datasets/tabular/synthetic/churn.txt ./" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "カレントディレクトリ下にダウンロードされた `churn.txt` を `pandas` を利用して読み込んでみます。 `pandas` は、表形式のデータを読み込んで、様々な加工ができるライブラリです。例えば、以下を実行すると表形式でのデータ表示が可能です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "churn = pd.read_csv('./churn.txt')\n", "pd.set_option('display.max_columns', 500)\n", "churn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "データをみると 5,000 行のデータしかなく、現在の機械学習の状況から見ると、やや小さいデータセットです。各データのレコードは、ある米国の携帯電話会社の顧客のプロフィールを説明する21の属性からなります。その属性というのは、\n", "\n", "- `State`: 顧客が居住している米国州で、2文字の省略形で記載されます (OHとかNJのように)\n", "- `Account Length`: アカウントが利用可能になってからの経過日数\n", "- `Area Code`: 顧客の電話番号に対応する3桁のエリアコード\n", "- `Phone`: 残りの7桁の電話番号\n", "- `Int’l Plan`: 国際電話のプランに加入しているかどうか (yes/no)\n", "- `VMail Plan`: Voice mail の機能を利用しているかどうか (yes/no)\n", "- `VMail Message`: 1ヶ月の Voice mail のメッセージの平均長\n", "- `Day Mins`: 1日に通話した時間(分)の総和\n", "- `Day Calls`: 1日に通話した回数の総和\n", "- `Day Charge`: 日中の通話にかかった料金\n", "- `Eve Mins, Eve Calls, Eve Charge`: 夜間通話にかかった料金\n", "- `Night Mins`, `Night Calls`, `Night Charge`: 深夜通話にかかった料金\n", "- `Intl Mins`, `Intl Calls`, `Intl Charge`: 国際通話にかかった料金\n", "- `CustServ Calls`: カスタマーサービスに電話をかけた回数\n", "- `Churn?`: そのサービスから離反したかどうか (true/false)\n", "\n", "最後の属性 `Churn?` は目的変数として知られ、MLのモデルで予測する属性になります。目的変数は2値 (binary) なので、ここで作成するモデルは2値の予測を行います。これは2値分類といわれます。\n", "\n", "それではデータを詳しく見てみます。\n", "\n", "まずはカテゴリデータごとにデータの頻度をみてみます。カテゴリデータは、`State`, `Area code`, `Phone`, `Int’l Plan`, `VMail Plan`, `Churn?`で、カテゴリを表す文字列や数値がデータとして与えられているものです。`pandas`ではある程度自動で、カテゴリデータを認識し、`object`というタイプでデータを保存します。以下では、`object` 形式のデータをとりだして、カテゴリごとの頻度を表示します。\n", "\n", "また `describe()`を利用すると各属性の統計量を一度に見ることができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Frequency tables for each categorical feature\n", "for column in churn.select_dtypes(include=['object']).columns:\n", " display(pd.crosstab(index=churn[column], columns='% observations', normalize='columns'))\n", "\n", "# Histograms for each numeric features\n", "display(churn.describe())\n", "%matplotlib inline\n", "hist = churn.hist(bins=30, sharey=True, figsize=(10, 10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "データを見てみると以下のことに気づくと思います。\n", "\n", "- `State` の各頻度はだいたい一様に分布しています。\n", "- `Phone` はすべて同じ数値になっていて手がかりになりそうにありません。この電話番号の最初の3桁はなにか意味がありそうですが、その割当に意味がないのであれば、使うのは止めるべきでしょう\n", "- 数値的な特徴量は都合の良い形で分布しており、多くは釣り鐘のようなガウス分布をしています。ただ、`VMail Message`は例外です。\n", "- `Area code` は数値データとみなされているようなので、非数値に変換しましょう\n", "\n", "さて、実際に`Phone`の列を削除して、`Area code`を非数値に変換します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "churn = churn.drop('Phone', axis=1)\n", "churn['Area Code'] = churn['Area Code'].astype(object)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "それでは次に各属性の値を、目的変数の True か False か、にわけて見てみます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for column in churn.select_dtypes(include=['object']).columns:\n", " if column != 'Churn?':\n", " display(pd.crosstab(index=churn[column], columns=churn['Churn?'], normalize='columns'))\n", "\n", "for column in churn.select_dtypes(exclude=['object']).columns:\n", " print(column)\n", " hist = churn[[column, 'Churn?']].hist(by='Churn?', bins=30)\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "データ分析の結果から、離反する顧客について、以下のような傾向が考えられます。\n", "\n", "- 地理的にもほぼ一様に分散している\n", "- 国際通話を利用している\n", "- VoiceMailを利用していない\n", "- 通話時間で見ると長い通話時間と短い通話時間の人に分かれる\n", "- カスタマーサービスへの通話が多い (多くの問題を経験した顧客ほど離反するというのは理解できる)\n", "\n", "加えて、離反する顧客に関しては、`Day Mins` と `Day Charge` で似たような分布を示しています。しかし、話せば話すほど、通常課金されるので、驚くことではないです。もう少し深く調べてみましょう。`corr()` を利用すると相関係数を求めることができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "display(churn.corr())\n", "pd.plotting.scatter_matrix(churn, figsize=(12, 12))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "いくつかの特徴は互いに100%の相関をもっています。このような特徴があるとき、機械学習のアルゴリズムによっては全くうまくいかないことがあり、そうでなくても結果が偏ったりしてしまうことがあります。これらの相関の強いペアは削除しましょう。Day Mins に対する Day Charge、Night Mins に対する Night Charge、Intl Mins に対する Intl Charge を削除します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "churn = churn.drop(['Day Charge', 'Eve Charge', 'Night Charge', 'Intl Charge'], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ここまででデータセットの前処理は完了です。これから利用する機械学習のアルゴリズムを決めましょう。前述したように、数値の大小 (中間のような数値ではなく)で離反を予測するような変数を用意すると良さそうです。線形回帰のようなアルゴリズムでこれを行う場合は、複数の項(もしくはそれらをまとめた項)を属性として用意する必要があります。\n", "\n", "そのかわりに、これを勾配ブースティング木 (Gradient Boosted Tree)を利用しましょう。Amazon SageMaker は、マネージドで、分散学習が設定済みで、リアルタイム推論のためのホスティングも可能な XGBoost コンテナを用意しています。XGBoost は、特徴感の非線形な関係を考慮した勾配ブースティング木を利用しており、特徴間の複雑な関連性を扱うことができます。\n", "\n", "Amazon SageMaker の XGBoostは、csv または LibSVM 形式のデータを学習することができます。ここでは csv を利用します。csv は以下のようなデータである必要があります。\n", "\n", "- 1列目が予測対象のデータ\n", "- ヘッダ行はなし\n", "\n", "まずはじめに、カテゴリ変数を数値データに変換する必要があります。`get_dummies()` を利用すると数値データへの変換が可能です。\n", "\n", "そして、`Churn?_True`のデータを最初の列にもってきて、`Churn?_False.`, `Churn?_True.`のデータを削除した残りのデータをconcatenate (連結) します。\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_data = pd.get_dummies(churn)\n", "model_data = pd.concat([model_data['Churn?_True.'], model_data.drop(['Churn?_False.', 'Churn?_True.'], axis=1)], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ここで学習用、バリデーション用、テスト用データにわけましょう。これによって overfitting (学習用データには精度が良いが、実際に利用すると制度が悪い、といった状況) を回避しやすくなり、未知のテストデータに対する精度を確認することができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9 * len(model_data))])\n", "train_data.to_csv('train.csv', header=False, index=False)\n", "validation_data.to_csv('validation.csv', header=False, index=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "学習には学習用データとバリデーション用データのみが必要です。上で csv に出力したデータをS3にアップロードして学習に利用できるようにします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sagemaker_session = sagemaker.Session()\n", "input_train = sagemaker_session.upload_data(path='train.csv', key_prefix='sagemaker/DEMO-xgboost-churn')\n", "input_validation = sagemaker_session.upload_data(path='validation.csv', key_prefix='sagemaker/DEMO-xgboost-churn')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`input_train` と `input_validation` にはアップロードしたファイルのS3パスが保存されています。これらは csv ファイルですが、Amazon SageMaker が用意している XGBoost のコンテナは、ファイルをデフォルトで libsvm 形式と認識してしまうため、このままだとエラーが発生します。\n", "`TrainingInput`という関数を利用して、`content_type='text/csv'`を明示的に指定することで、csv 形式と認識させることができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# from sagemaker.session import s3_input\n", "from sagemaker.inputs import TrainingInput\n", "\n", "content_type='text/csv'\n", "s3_input_train = TrainingInput(input_train, content_type=content_type)\n", "s3_input_validation = TrainingInput(input_validation, content_type=content_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 4.学習\n", "\n", "それでは学習を始めましょう。まず、XGBoost のコンテナの場所を取得します。コンテナ自体は SageMaker 側で用意されているので、場所を指定すれば利用可能です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "container = sagemaker.image_uris.retrieve(\"xgboost\", boto3.Session().region_name, \"1.2-1\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "学習のためにハイパーパラメータを指定したり、学習のインスタンスの数やタイプを指定することができます。XGBoost における主要なハイパーパラメータは以下のとおりです。\n", "\n", "- `max_depth` アルゴリズムが構築する木の深さをコントロールします。深い木はより学習データに適合しますが、計算も多く必要で、overfiting になる可能性があります。たくさんの浅い木を利用するか、少数の深い木を利用するか、モデルの性能という面ではトレードオフがあります。\n", "- `subsample` 学習データのサンプリングをコントロールします。これは overfitting のリスクを減らしますが、小さすぎるとモデルのデータが不足してしまいます。\n", "- `num_round` ブースティングを行う回数をコントロールします。以前のイテレーションで学習したときの残差を、以降のモデルにどこまで利用するかどうかを決定します。多くの回数を指定すると学習データに適合しますが、計算も多く必要で、overfiting になる可能性があります。\n", "- `eta` 各ブースティングの影響の大きさを表します。大きい値は保守的なブースティングを行います。\n", "- `gamma` ツリーの成長の度合いをコントロールします。大きい値はより保守的なモデルを生成します。\n", "\n", "XGBoostのhyperparameterに関する詳細は [github](https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst) もチェックしてください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sess = sagemaker.Session()\n", "\n", "hyperparameters = {\"max_depth\":\"5\",\n", " \"eta\":\"0.2\",\n", " \"gamma\":\"4\",\n", " \"min_child_weight\":\"6\",\n", " \"subsample\":\"0.8\",\n", " \"objective\":\"binary:logistic\",\n", " \"num_round\":\"100\"}\n", "\n", "xgb = sagemaker.estimator.Estimator(container,\n", " role, \n", " hyperparameters=hyperparameters,\n", " instance_count=1, \n", " instance_type='ml.m4.xlarge',\n", " sagemaker_session=sess)\n", "\n", "xgb.fit({'train': s3_input_train, 'validation': s3_input_validation}) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 5.ホスティング\n", "\n", "学習が終われば、`deploy()`を実行することで、エンドポイントを作成してモデルをデプロイできます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xgb_predictor = xgb.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5-1.評価\n", "\n", "現在、エンドポイントをホストしている状態で、これを利用して簡単に予測を行うことができます。予測は http の POST の request を送るだけです。\n", "ここではデータを `numpy` の `array` の形式で送って、予測を得られるようにしたいと思います。しかし、endpoint は `numpy` の `array` を受け取ることはできません。\n", "\n", "このために、`csv_serializer` を利用して、csv 形式に変換して送ることができます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xgb_predictor.serializer = sagemaker.serializers.CSVSerializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "作成済みのテストデータを受け取ると、これをデフォルト500行ずつのデータにわけて、エンドポイントに送信する `predict` という関数を用意します。あとは `predict` を実行して予測結果を受け取ります。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def predict(data, rows=500):\n", " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", " predictions = ''\n", " for array in split_array:\n", " predictions = ','.join([predictions, xgb_predictor.predict(array).decode('utf-8')])\n", "\n", " return np.fromstring(predictions[1:], sep=',')\n", "\n", "dtest = test_data.values\n", "predictions = []\n", "predictions.append(predict(dtest[:, 1:]))\n", "predictions = np.array(predictions).squeeze()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "機械学習の性能を比較評価する方法はいくつかありますが、単純に、予測値と実際の値を比較しましょう。今回は、顧客が離反する `1` と離反しない `0` を予測しますので、この混同行列を作成します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pd.crosstab(index=test_data.iloc[:, 0], columns=np.round(predictions), rownames=['actual'], colnames=['predictions'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_注意点, アルゴリズムにはランダムな要素があるので結果は必ずしも一致しません._\n", "\n", "48人の離反者がいて、それらの39名 (true positives) を正しく予測できました。そして、4名の顧客は離反すると予測しましたが、離反していません (false positives)。9名の顧客は離反しないと予測したにもかかわらず離反してしまいました (false negatives)。\n", "\n", "重要な点として、離反するかどうかを `np.round()` という関数で、しきい値0.5で判断しています。`xgboost` が出力する値は0から1までの連続値で、それらを離反する `1` と 離反しない `0` に分類します。しかし、その連続値 (離反する確率) が示すよりも、顧客の離反というのは損害の大きい問題です。つまり離反する確率が低い顧客も、しきい値を0.5から下げて、離反するとみなす必要があるかもしれません。もちろんこては、false positives (離反すると予測したけど離反しなかった)を増やすと思いますが、 true positives (離反すると予測して離反した) を増やし、false negatives (離反しないと予測して離反した)を減らせます。\n", "\n", "直感的な理解のため、予測結果の連続値をみてみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.hist(predictions)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "連続値は0から1まで歪んでいますが、0.1から0.9までの間で、しきい値を調整するにはちょうど良さそうです。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "例えば、しきい値を0.5から0.3まで減らしてみたとき、true positives は 1 つ、false positives は 3 つ増え、false negatives は 1 つ減りました。全体からみると小さな値ですが、全体の6-10%の顧客が、しきい値の変更で、予測結果が変わりました。ここで5名にインセンティブを与えることによって、インセンティブのコストが掛かりますが、3名の顧客を引き止めることができるかもしれません。\n", "つまり、最適な閾値を決めることは、実世界の問題を機械学習で解く上で重要なのです。これについてもう少し広く議論し、仮説的なソリューションを考えたいと思います。\n", "\n", "### 5-2.推論エラーのコスト\n", "\n", "2値分類の問題においては、しきい値に注意しなければならないという、似たような状況に直面することが多いです。それ自体は問題ではありません。もし、出力の連続値が2クラスで完全に別れていれば、MLを使うことなく単純なルールで解くことができると考えられます。\n", "\n", "重要なこととして、MLモデルを正版環境に導入する際、モデルが false positives と false negatives に誤って入れたときのコストがあげられます。しきい値の選択は4つの指標に影響を与えます。4つの指標に対して、ビジネス上の相対的なコストを考える必要があるでしょう。\n", "\n", "#### コストの割当\n", "\n", "携帯電話会社の離反の問題において、コストとはなんでしょうか?コストはビジネスでとるべきアクションに結びついています。いくつかの仮定をおいてみましょう。\n", "\n", "まず、true negatives のコストとして \\$0 を割り当てます。満足しているお客様を正しく認識できていれば何も実施しません。\n", "\n", "false negatives が一番問題で、なぜなら、離反していく顧客を正しく予測できないからです。顧客を失えば、再獲得するまでに多くのコストを払う必要もあり、例えば逸失利益、広告コスト、管理コスト、販売管理コスト、電話の購入補助金などがあります。インターネットを簡単に検索してみると、そのようなコストは数百ドルとも言われ、ここでは `$500` としましょう。これが false negatives に対するコストです。\n", "\n", "最後に、離反していくと予測された顧客に `$100` のインセンティブを与えることを考えましょう。\n", "携帯電話会社がそういったインセンティブを提供するなら、2回くらいは離反の前に考え直すかもしれません。これは true positive と false negative のコストになります。false positives の場合 (顧客は満足していて、モデルが誤って離反しそうと予測した場合)、 `$100` のインセンティブは捨てることになります。その `$100` を効率よく消費してしまうかもしれませんが、優良顧客へのロイヤリティを増やすという意味では悪くないかもしれません。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5-3.最適な閾値を探す\n", "\n", "false negatives が false positives よりもコストが高いことは説明しました。そこで、顧客の数ではなく、コストを最小化するように、しきい値を最適化することを考えましょう。コストの関数は以下のようなものになります。\n", "\n", "```txt\n", "$500 * FN(C) + $0 * TN(C) + $100 * FP(C) + $100 * TP(C)\n", "```\n", "\n", "FN(C) は false negative の割合で、しきい値Cの関数です。同様にTN, FP, TP も用意します。この関数の値が最小となるようなしきい値Cを探します。\n", "最も単純な方法は、候補となる閾値で何度もシミュレーションをすることです。以下では100個の値に対してループで計算を行います。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cutoffs = np.arange(0.01, 1, 0.01)\n", "costs = []\n", "\n", "for c in cutoffs:\n", " _predictions = pd.Categorical(np.where(predictions > c, 1, 0), categories=[0, 1])\n", " matrix_a = np.array([[0, 100], [500, 100]])\n", " matrix_b = pd.crosstab(index=test_data.iloc[:, 0], columns=_predictions, dropna=False)\n", " costs.append(np.sum(np.sum(matrix_a * matrix_b)))\n", "\n", "costs = np.array(costs)\n", "plt.plot(cutoffs, costs)\n", "plt.show()\n", "print('Cost is minimized near a cutoff of:', cutoffs[np.argmin(costs)], 'for a cost of:', np.min(costs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6.エンドポイントの削除\n", "\n", "エンドポイントは起動したままだとコストがかかります。不要な場合は削除します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xgb_predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }