{ "cells": [ { "cell_type": "markdown", "id": "df276921", "metadata": {}, "source": [ "# Churn Prediction with Text and Interpretability" ] }, { "cell_type": "markdown", "id": "1c1f7791", "metadata": {}, "source": [ "This notebook runs the full churn prediction pipeline, from data preparation through model training and evaluation to interpretation.\n", "\n", "Alternatively, the same steps can be run from the terminal (see README.md).\n", "\n", "Prerequisite: the dataset has already been created (see README.md)." ] }, { "cell_type": "markdown", "id": "032bb6ce", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "a32ade35", "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "\n", "# switch to the scripts directory so the pipeline modules below can be imported\n", "os.chdir(\"../scripts\")\n", "\n", "import preprocess\n", "import train\n", "import interpret" ] }, { "cell_type": "code", "execution_count": 3, "id": "9d500e76", "metadata": {}, "outputs": [], "source": [ "# automatically reload edited modules before each cell is executed\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "id": "4b8b1336", "metadata": {}, "source": [ "### Load and Prepare the Data" ] }, { "cell_type": "code", "execution_count": 4, "id": "94c30e10", "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " churn chat_log state \\\n", "0 no Customer: Well, the only thing that I'm consid... CT \n", "1 yes Customer: Well, I just want to be able to canc... WV \n", "2 no Customer: I would like data.\\nTelCom Agent: Ok... IN \n", "\n", " account_length area_code international_plan voice_mail_plan \\\n", "0 134 area_code_408 no no \n", "1 78 area_code_408 no no \n", "2 88 area_code_415 no no \n", "\n", " number_vmail_messages total_day_minutes total_day_calls ... \\\n", "0 0 177.2 91 ... \n", "1 0 226.3 88 ... \n", "2 0 183.5 93 ... \n", "\n", " total_eve_minutes total_eve_calls total_eve_charge total_night_minutes \\\n", "0 228.7 105 19.44 194.3 \n", "1 306.2 81 26.03 200.9 \n", "2 170.5 80 14.49 193.8 \n", "\n", " total_night_calls total_night_charge total_intl_minutes \\\n", "0 113 8.74 8.9 \n", "1 120 9.04 7.8 \n", "2 88 8.72 8.3 \n", "\n", " total_intl_calls total_intl_charge number_customer_service_calls \n", "0 3 2.40 2 \n", "1 11 2.11 1 \n", "2 5 2.24 3 \n", "\n", "[3 rows x 21 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('../data/churn_dataset.csv')\n", "df.head(3)" ] }, { "cell_type": "code", "execution_count": null, "id": "ede78f6b", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = preprocess.prep_data(df, use_existing=False, test_size=0.33)" ] }, { "cell_type": "code", "execution_count": 27, "id": "5d0f6d4f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((2233, 841), (1100, 841), (2233, 1), (1100, 1))" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape, X_test.shape, y_train.shape, y_test.shape" ] }, { "cell_type": "markdown", "id": "c477f751", "metadata": {}, "source": [ "### Train and Evaluate the Model" ] }, { "cell_type": "code", "execution_count": 28, "id": "0d68bf05", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"/home/ec2-user/SageMaker/churn_test/scripts\n" ] } ], "source": [ "!pwd" ] }, { "cell_type": "code", "execution_count": 29, "id": "eb2e2878", "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "starting epoch: 1\n", "Train Epoch: 1, train-auc-score: 0.9561\n", "test_auc_score: 0.9422\n", "starting epoch: 2\n", "Train Epoch: 2, train-auc-score: 0.9571\n", "test_auc_score: 0.9453\n", "starting epoch: 3\n", "Train Epoch: 3, train-auc-score: 0.9602\n", "test_auc_score: 0.9480\n", "starting epoch: 4\n", "Train Epoch: 4, train-auc-score: 0.9594\n", "test_auc_score: 0.9467\n", "starting epoch: 5\n", "Train Epoch: 5, train-auc-score: 0.9628\n", "test_auc_score: 0.9529\n", "starting epoch: 6\n", "Train Epoch: 6, train-auc-score: 0.9711\n", "test_auc_score: 0.9555\n", "starting epoch: 7\n", "Train Epoch: 7, train-auc-score: 0.9756\n", "test_auc_score: 0.9586\n", "starting epoch: 8\n", "Train Epoch: 8, train-auc-score: 0.9804\n", "test_auc_score: 0.9598\n", "starting epoch: 9\n", "Train Epoch: 9, train-auc-score: 0.9810\n", "test_auc_score: 0.9618\n", "starting epoch: 10\n", "Train Epoch: 10, train-auc-score: 0.9644\n", "test_auc_score: 0.9552\n", "saving scores\n", "saving model\n" ] } ], "source": [ "# train the model\n", "train.train(\n", " X=X_train,\n", " y=y_train,\n", " X_test=X_test,\n", " y_test=y_test\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "id": "294bd939", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot training stats\n", "train.plot_train_stats()" ] }, { "cell_type": "code", "execution_count": 31, "id": "75e22bb5", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot pr curve\n", "train.plot_pr_curve(X_test, y_test)" ] }, { "cell_type": "markdown", "id": "fa34b9d3", "metadata": {}, "source": [ "### Interpret the Model" ] }, { "cell_type": "markdown", "id": "2ca793cf", "metadata": {}, "source": [ "#### Categorical and Numerical Features" ] }, { "cell_type": "code", "execution_count": 32, "id": "5ad792da", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/xgboost/sklearn.py:1146: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].\n", " warnings.warn(label_encoder_deprecation_msg, UserWarning)\n", "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sklearn/utils/validation.py:63: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", " return f(*args, **kwargs)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[16:15:16] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "preds_xgb = interpret.train_xgb()" ] }, { "cell_type": "markdown", "id": "9e80253e", "metadata": {}, "source": [ "#### Textual Features (focus on customer chats that result in churn)" ] }, { "cell_type": "code", "execution_count": 9, "id": "7f18dcf5", "metadata": {}, "outputs": [], "source": [ "chats, df_sub = interpret.get_chats(df=df, churn=1, speaker='Customer')" ] }, { "cell_type": "code", "execution_count": 10, "id": "0b65267d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[\"Well, I just want to be able to cancel the contract because I don't think that I want to stay. My local provider has been terrible and I really would like to switch. Sure, I can.\",\n", " \"Well, it's the old TelCom billing system for the last 5 years. I don't trust anymore and I think you should change to the newer billing system. I would like to give you a call back number. Okay, I can see why you need the new billing system, but I don't know if I can do that. I would like to know your cancellation policy.\",\n", " \"Well, I've been getting phone calls from a very good friend who's a TelCom agent and I have told him the same thing and the problem has not been resolved. He has offered me a $20/mo deal but that's not good enough for me because I'm getting $20 out of his pocket. Sure. 
$60 to cancel for nine months with a $20/mo bonus.\"]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chats[:3]" ] }, { "cell_type": "markdown", "id": "8c73e7dd", "metadata": {}, "source": [ "##### Candidate keywords (POS tagging, lowercasing, lemmatization)" ] }, { "cell_type": "code", "execution_count": 11, "id": "3db8f636", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['want', 'able', 'cancel', 'contract', 'think', 'want', 'stay', 'local', 'provider', 'terrible']\n" ] } ], "source": [ "# find candidate keywords\n", "keywords, tokens = interpret.get_keywords(chats)\n", "print(keywords[:10])\n", "\n", "# map keywords to original tokens\n", "keywords_dict = interpret.map_to_orig_tok(keywords, tokens)" ] }, { "cell_type": "markdown", "id": "172eed77", "metadata": {}, "source": [ "##### Relevant keywords (semantic similarity)" ] }, { "cell_type": "code", "execution_count": 12, "id": "07be6aa0", "metadata": {}, "outputs": [], "source": [ "relevant_keywords, simMat = interpret.get_relevant_keywords(\n", " text=chats,\n", " keywords_dict=keywords_dict\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "267a00b0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['voicemail', 'bored', 'spam', 'backlog', 'sick', 'frustrated', 'incompetence', 'overcharge', 'angry', 'disappointed', 'scam', 'lag', 'termination', 'resign', 'nightmare', 'incompetent', 'frustration', 'lagging', 'yesterday', 'friday', 'monday', 'discontinue', 'cheat', 'dead', 'annoyed']\n" ] } ], "source": [ "print(relevant_keywords[:25])" ] }, { "cell_type": "markdown", "id": "3ba0da26", "metadata": {}, "source": [ "##### Impactful keywords (marginal contribution)" ] }, { "cell_type": "code", "execution_count": 14, "id": "8c877937", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"100%|██████████| 250/250 [11:25<00:00, 2.74s/it]\n" ] } 
], "source": [ "# get marginal contribution to prediction for each keyword\n", "marg_contr_df = interpret.perform_ablation(\n", " df=df_sub,\n", " keywords=relevant_keywords,\n", " keywords_dict=keywords_dict\n", ")" ] }, { "cell_type": "markdown", "id": "d43f8aa6", "metadata": {}, "source": [ "##### Create joint metric (semantic similarity + marginal contribution + count)" ] }, { "cell_type": "code", "execution_count": null, "id": "c02d300a", "metadata": {}, "outputs": [], "source": [ "# optional: load previously saved results from local disk instead of recomputing them\n", "#results_df = pd.read_csv('model/ablation_results.csv')\n", "#results_df = results_df.rename(columns={'Unnamed: 0' : 'keyword'})" ] }, { "cell_type": "code", "execution_count": 15, "id": "2e112e00", "metadata": {}, "outputs": [], "source": [ "results_df = interpret.get_important_keywords(\n", " simMat_df=simMat,\n", " marg_contr_df=marg_contr_df\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "efd90445", "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " keyword sim chg count joint\n", "0 voicemail 90.076553 0.006695 5 0.628774\n", "1 cancel 61.187359 -0.081321 168 0.545633\n", "2 sick 74.789459 -0.127242 1 0.538919\n", "3 turnover 60.896118 -0.286630 1 0.533321\n", "4 disappointed 70.248131 -0.091740 5 0.522520\n", "5 spam 77.940460 -0.000022 3 0.506429\n", "6 bored 78.131271 -0.038932 1 0.502213\n", "7 unhappy 65.601990 -0.024782 37 0.493910\n", "8 frustrated 73.496033 -0.006023 5 0.486930\n", "9 mistake 66.247879 -0.093309 3 0.470609\n", "10 late 65.971649 -0.003873 18 0.458023\n", "11 error 63.123695 -0.065609 7 0.448412\n", "12 faulty 55.486092 -0.239817 1 0.448232\n", "13 angry 70.865860 -0.002967 3 0.444018\n", "14 backlog 74.808548 0.000020 1 0.442185\n", "15 customer 52.072552 -0.006985 480 0.439737\n", "16 lag 69.912178 -0.004650 3 0.436584\n", "17 overpay 65.573105 -0.093846 1 0.429261\n", "18 disconnect 64.002014 -0.010485 11 0.429104\n", "19 incompetence 73.107452 -0.000754 1 0.427229" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_df.head(20)" ] }, { "cell_type": "markdown", "id": "7036dc83", "metadata": {}, "source": [ "##### Context of keywords" ] }, { "cell_type": "code", "execution_count": 17, "id": "5c0f0d2b", "metadata": {}, "outputs": [], "source": [ "keyword_of_interest = 'spam'" ] }, { "cell_type": "code", "execution_count": 18, "id": "041befbf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Basically, I'm getting a lot of spam calls every day from a guy named Michael who's calling from a really weird number.\n", "TelCom started to flood me with emails and phone calls, spamming me with thousands of phony invoices.\n", "I just got some spam messages last night, and today it's been getting a lot of texts that I \"don't have my SIM card\" and \"I need my SIM card.\n" ] } ], "source": [ "interpret.obtain_context(\n", " chats_list=chats,\n", " keyword=keyword_of_interest\n", ")" 
] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }