{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Use the %pip magic (not !pip) so the packages are installed into the running kernel's environment.\n",
    "%pip install -r requirements.txt"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1_Render Shap Value\n",
    "Note: because of an ongoing bug in SHAP's DeepExplainer with TensorFlow 2.0+, the model training framework version was reverted to TF 1.15."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# S3 location of the data sets and of the batch-transform output read further down.\n",
    "# NOTE(review): the bucket name looks like a truncated placeholder - set the real bucket before running.\n",
    "bucket = 'pop-modeling-'\n",
    "prefix = 'sagemaker/synthea'\n",
    "\n",
    "RANDOM_STATE = 2021\n",
    "\n",
    "# SageMaker session and IAM execution role\n",
    "import boto3\n",
    "import re\n",
    "from sagemaker import get_execution_role\n",
    "import sagemaker\n",
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "role = get_execution_role()\n",
    "s3 = boto3.resource('s3')\n",
    "my_session = boto3.session.Session()\n",
    "\n",
    "# Make the helper modules under ml/ (data, config) importable before importing from them.\n",
    "import sys\n",
    "sys.path.append('ml')\n",
    "from data import (\n",
    "    load_trained_tf_model,\n",
    "    read_csv_s3,\n",
    "    download_csv_s3,\n",
    "    load_embedding_matrix_trained_model,\n",
    "    pad_sequences,\n",
    "    get_csv_output_from_s3,\n",
    ")\n",
    "from config import MAX_LENGTH, EMBEDDING_DIM\n",
    "import shap\n",
    "import numpy as np"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Embedding matrix plus the code <-> index lookups returned by the project helper.\n",
    "embedding_matrix, idx_to_code_map, code_to_idx_map, unknown_idx = load_embedding_matrix_trained_model(\n",
    "    model_name='data/W2V_event_dim100_win100.model', embedding_dim=EMBEDDING_DIM)\n",
    "\n",
    "\n",
    "def read_split_csv(split, filename):\n",
    "    '''Read <prefix>/sagemaker/<split>/<filename> from the configured bucket via read_csv_s3.'''\n",
    "    return read_csv_s3(bucket, '{}/{}/{}'.format(prefix, 'sagemaker', split), filename)\n",
    "\n",
    "\n",
    "train_input = read_split_csv('train', 'train.csv')\n",
    "train_static_input = read_split_csv('train', 'train-static.csv')\n",
    "valid_input = read_split_csv('validation', 'validation.csv')\n",
    "valid_static_input = read_split_csv('validation', 'valid-static.csv')\n",
    "test_input = read_split_csv('test', 'test.csv')\n",
    "test_static_input = read_split_csv('test', 'test-static.csv')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def encode_events(events):\n",
    "    '''Turn each space-separated event string into a list of indices; codes missing from code_to_idx_map map to unknown_idx.'''\n",
    "    return events.apply(lambda seq: [code_to_idx_map.get(code, unknown_idx) for code in seq.split(' ')])\n",
    "\n",
    "\n",
    "# Pre-pad every encoded sequence to MAX_LENGTH; the static features are used as plain arrays.\n",
    "train_encoded_seq = encode_events(train_input['events'])\n",
    "X_train = pad_sequences(train_encoded_seq, maxlen=MAX_LENGTH, padding='pre')\n",
    "X_train_static = train_static_input.values\n",
    "test_encoded_seq = encode_events(test_input['events'])\n",
    "X_test = pad_sequences(test_encoded_seq, maxlen=MAX_LENGTH, padding='pre')\n",
    "X_test_static = test_static_input.values"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Batch-transform output for the test set, fetched from S3 and flattened into an array.\n",
    "batch_transform_uri = 's3://{}/{}/{}'.format(bucket, prefix, 'sagemaker/batch-transform')\n",
    "test_output = get_csv_output_from_s3(batch_transform_uri, 'test.jsonl.out')\n",
    "y_prob = np.array(test_output).squeeze()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# NOTE(review): left empty in the notebook - presumably the name of the finished SageMaker training job\n",
    "# must be filled in before this cell can load a model; confirm against load_trained_tf_model.\n",
    "TRAINED_MODEL_JOB_NAME = ''\n",
    "tf_model = load_trained_tf_model(sagemaker_session, TRAINED_MODEL_JOB_NAME)\n",
    "# The training inputs (event sequences first, static features second) serve as the background data.\n",
    "explainer = shap.DeepExplainer(tf_model, [X_train, X_train_static])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Test-set row to explain. The SHAP values and the event-code labels must come from the same row,\n",
    "# so both are indexed with SAMPLE_IDX.\n",
    "# NOTE(review): row 10 is kept as the explained patient - confirm that this is the intended row.\n",
    "SAMPLE_IDX = 10\n",
    "shap_values = explainer.shap_values(\n",
    "    [X_test[SAMPLE_IDX:SAMPLE_IDX + 1], X_test_static[SAMPLE_IDX:SAMPLE_IDX + 1]])\n",
    "shap.initjs()\n",
    "# One row of readable event codes for the explained sequence; indices missing from idx_to_code_map become UNK.\n",
    "x_test_codes = np.array([[idx_to_code_map.get(idx, \"UNK\") for idx in X_test[SAMPLE_IDX]]])\n",
    "# shap_values is presumably indexed as [output][input]; the event-sequence input is first in the explainer's input list.\n",
    "explainer_plot = shap.force_plot(explainer.expected_value[0], shap_values[0][0], x_test_codes[0])\n",
    "shap.save_html('demo/shap.html', explainer_plot)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2_Create More Visual Components"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import json\n",
    "import urllib\n",
    "import urllib.request\n",
    "\n",
    "import boto3\n",
    "import htmlmin\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.figure_factory as ff\n",
    "import plotly.graph_objects as go\n",
    "from bs4 import BeautifulSoup\n",
    "from plotly.offline import plot\n",
    "\n",
    "s3 = boto3.resource('s3')\n",
    "from demo.config import sunburst_html, dashboard_style"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "* Plotly Sankey Chart"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# The position of each label is the node id referenced by the link lists below.\n",
    "node_labels = [\"Chronic Kidney Disease\", \"Hypertension\", \"Diabetes\", \"LOW-Risk\", \"MEDIUM-Risk\", \"HIGH-Risk\"]\n",
    "sankey_nodes = dict(\n",
    "    pad=15,\n",
    "    thickness=20,\n",
    "    line=dict(color=\"black\", width=0.5),\n",
    "    label=node_labels,\n",
    ")\n",
    "# Column i of (source, target, value) is one flow: source[i] -> target[i] with weight value[i].\n",
    "sankey_links = dict(\n",
    "    source=[0, 0, 1, 1, 1, 2, 2, 2],\n",
    "    target=[1, 2, 3, 4, 5, 3, 4, 5],\n",
    "    value=[70, 30, 25, 15, 30, 40, 20, 10],\n",
    ")\n",
    "fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])\n",
    "# plotly.js is left out of the file (include_plotlyjs=False).\n",
    "fig.write_html(\"demo/flow_chart.html\", include_plotlyjs=False)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "* Convert CSV result to HTML table"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Render the dashboard CSV as an HTML table without the index column.\n",
    "dashboard_df = pd.read_csv('demo/dashboard.csv')\n",
    "prediction_table_html = dashboard_df.to_html(\n",
    "    index=False,\n",
    "    table_id='prediction_table',\n",
    "    classes=\"table-responsive table-striped table-hover table-sm\",\n",
    ")"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3_Assemble Everything Together"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "with open('demo/flow_chart.html', \"r\") as f:\n",
    " html = f.read()\n",
    "parsed_html = BeautifulSoup(html)\n",
    "body = parsed_html.find('body')\n",
    "sankey_html = body.findChildren(recursive=False)[0]\n",
    "\n",
    "with open('demo/shap.html', \"r\") as f:\n",
    " html = f.read()\n",
    "parsed_html = 
BeautifulSoup(html)\n", "body = parsed_html.find('body')\n", "shap_head = parsed_html.find('head').findChildren(recursive=False)[1]\n", "shap_html = body.findChildren(recursive=False)\n", "\n", "# Get Shap HTML\n", "\n", "message = f\"\"\"\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " Dashboard\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " {shap_head}\n", " \n", " \n", "\n", "\n", " \n", " \n", "
\n", "

Dashboard

\n", " {prediction_table_html}\n", "
\n", " \n", " \n", "
\n", "
\n", "
\n", "

Physician Level Patient Flow

\n", " {sankey_html}\n", "
\n", "
\n", "
\n", "
\n", "

Patient Outcome Indicator

\n", " {shap_html[0]}\n", " {shap_html[1]}\n", "
\n", "
\n", "
\n", " \n", "
\n", "
\n", " {sunburst_html}\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", "\"\"\"\n", "mined_html = htmlmin.minify(message, remove_comments=True, remove_empty_space=True, remove_all_empty_space=True)\n", "with open('demo/index.html','w') as f:\n", " f.write(mined_html)" ], "outputs": [], "metadata": {} } ], "metadata": { "kernelspec": { "display_name": "conda_tensorflow_p36", "language": "python", "name": "conda_tensorflow_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }