In [None]:
! pip install -r requirements.txt

## 1_Render Shap Value
Note: because of an ongoing incompatibility between SHAP's DeepExplainer and TensorFlow 2.0+, the model training framework had to be reverted to TensorFlow 1.15.

In [None]:
%load_ext autoreload
%autoreload 2
bucket = 'pop-modeling-<your-account-id>'
prefix = 'sagemaker/synthea'

RANDOM_STATE = 2021

# Define IAM role
import boto3
import re
from sagemaker import get_execution_role
import sagemaker
sagemaker_session = sagemaker.Session()

role = get_execution_role()
s3 = boto3.resource('s3')
my_session = boto3.session.Session()

import sys
sys.path.append('ml')
from data import load_trained_tf_model, read_csv_s3, download_csv_s3, load_embedding_matrix_trained_model, pad_sequences, get_csv_output_from_s3
from config import MAX_LENGTH, EMBEDDING_DIM
import shap
import numpy as np

In [None]:
# Load the trained embedding matrix together with the lookup tables that
# translate between event codes and embedding indices.
(
    embedding_matrix,
    idx_to_code_map,
    code_to_idx_map,
    unkown_idx,
) = load_embedding_matrix_trained_model(
    model_name='data/W2V_event_dim100_win100.model',
    embedding_dim=EMBEDDING_DIM,
)

# Each split is read from <prefix>/sagemaker/<split>/ in the bucket: one CSV
# with the event sequences and one with the static features.
train_input = read_csv_s3(bucket, f'{prefix}/sagemaker/train', 'train.csv')
train_static_input = read_csv_s3(bucket, f'{prefix}/sagemaker/train', 'train-static.csv')

# NOTE(review): the static validation file is named 'valid-static.csv' while
# the sequence file is 'validation.csv' -- confirm this matches the upload step.
valid_input = read_csv_s3(bucket, f'{prefix}/sagemaker/validation', 'validation.csv')
valid_static_input = read_csv_s3(bucket, f'{prefix}/sagemaker/validation', 'valid-static.csv')

test_input = read_csv_s3(bucket, f'{prefix}/sagemaker/test', 'test.csv')
test_static_input = read_csv_s3(bucket, f'{prefix}/sagemaker/test', 'test-static.csv')

In [None]:
def _encode_events(events):
    """Turn a space-separated string of event codes into a list of indices.

    Codes that are missing from ``code_to_idx_map`` are mapped to
    ``unkown_idx``.
    """
    return [code_to_idx_map.get(code, unkown_idx) for code in events.split(' ')]


# Training inputs: encoded event sequences padded at the front to MAX_LENGTH,
# plus the static feature matrix.
train_encoded_seq = train_input['events'].apply(_encode_events)
X_train = pad_sequences(train_encoded_seq, maxlen=MAX_LENGTH, padding='pre')
X_train_static = train_static_input.values

# Test inputs, prepared the same way.
test_encoded_seq = test_input['events'].apply(_encode_events)
X_test = pad_sequences(test_encoded_seq, maxlen=MAX_LENGTH, padding='pre')
X_test_static = test_static_input.values

In [None]:
# Fetch the batch-transform output for the test set and drop any
# singleton dimensions from the resulting array.
batch_transform_uri = f's3://{bucket}/{prefix}/sagemaker/batch-transform'
test_output = get_csv_output_from_s3(batch_transform_uri, 'test.jsonl.out')
y_prob = np.squeeze(np.array(test_output))

In [None]:
# Name of the SageMaker training job whose model artifact should be explained.
TRAINED_MODEL_JOB_NAME = '<your-training-job-id>'
tf_model = load_trained_tf_model(sagemaker_session, TRAINED_MODEL_JOB_NAME)

# Background data for the explainer: both model inputs, in the order
# [event sequences, static features].
# NOTE(review): the full training set is used as background here; this may be
# slow for large training sets -- consider whether a random subset is enough.
background_inputs = [X_train, X_train_static]
explainer = shap.DeepExplainer(tf_model, background_inputs)

In [None]:
# Index of the single test record to explain. A one-row slice (rather than a
# scalar index) keeps the leading batch dimension on both inputs.
SAMPLE_IDX = 10
sample_rows = slice(SAMPLE_IDX, SAMPLE_IDX + 1)

shap_values = explainer.shap_values([X_test[sample_rows], X_test_static[sample_rows]])
shap.initjs()

# Decode the event indices of the SAME rows that were explained, so the labels
# in the force plot line up with the SHAP values. (Previously the codes of
# test row 0 were used to label the explanation of test row 10.)
x_test_codes = np.stack([
    np.array([idx_to_code_map.get(idx, "UNK") for idx in row])
    for row in X_test[sample_rows]
])

# shap_values[0][0] is presumably [first model output][first input, i.e. the
# event sequence] -- TODO confirm against the model's output layout.
# Row 0 of x_test_codes corresponds to row 0 of the explained slice.
explainer_plot = shap.force_plot(explainer.expected_value[0], shap_values[0][0], x_test_codes[0])
shap.save_html('demo/shap.html', explainer_plot)

## 2_Create More Visual Components

In [None]:
%load_ext autoreload
%autoreload 2
import plotly.express as px
import json, urllib
from plotly.offline import plot
import plotly.figure_factory as ff
import pandas as pd
import numpy as np

import urllib.request
from bs4 import BeautifulSoup
import plotly.graph_objects as go
import pandas as pd
import htmlmin

import boto3
import json
s3 = boto3.resource('s3')
from demo.config import sunburst_html, dashboard_style

* Plotly Sankey Chart

In [None]:
# Node styling and labels; the link lists below refer to nodes by their
# position in the "label" list.
sankey_node = {
    "pad": 15,
    "thickness": 20,
    "line": {"color": "black", "width": 0.5},
    "label": ["Chronic Kidney Disease", "Hypertension", "Diabetes", "LOW-Risk", "MEDIUM-Risk", "HIGH-Risk"],
}

# Link i goes from node source[i] to node target[i] with weight value[i].
sankey_link = {
    "source": [0,   0,  1,  1,   1,  2,  2,  2],
    "target": [1,   2,  3,  4,   5,  3,  4,  5],
    "value":  [70, 30,  25, 15,  30, 40, 20, 10],
}

fig = go.Figure(data=[go.Sankey(node=sankey_node, link=sankey_link)])

# plotly.js is not embedded in the output file (include_plotlyjs=False).
fig.write_html("demo/flow_chart.html", include_plotlyjs=False)

* Convert CSV result to HTML table

In [None]:
# Render the dashboard CSV as an HTML <table> with id 'prediction_table'
# and the CSS classes listed below.
dashboard_df = pd.read_csv('demo/dashboard.csv')
prediction_table_html = dashboard_df.to_html(
    index=False,
    table_id='prediction_table',
    classes="table-responsive table-striped table-hover table-sm",
)

## 3_Assemble Everything Together

In [None]:
def _load_soup(path):
    """Read an HTML file as UTF-8 and parse it with the stdlib ``html.parser``.

    The parser is named explicitly so the resulting tree does not depend on
    which optional parsers happen to be installed (and so BeautifulSoup does
    not emit its "no parser was explicitly specified" warning). The encoding
    is named explicitly so the read does not depend on the platform default.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return BeautifulSoup(fh.read(), "html.parser")


def _top_level_tags(soup, tag_name):
    """Return the direct child tags of the first ``<tag_name>`` element."""
    return soup.find(tag_name).find_all(recursive=False)


# Sankey chart: take the first top-level element of the exported page's <body>.
sankey_soup = _load_soup('demo/flow_chart.html')
sankey_html = _top_level_tags(sankey_soup, 'body')[0]

# SHAP force plot: the second child of <head> (expected to be the script
# bundle written by shap.save_html -- TODO confirm) and the top-level
# elements of <body>.
shap_soup = _load_soup('demo/shap.html')
shap_head = _top_level_tags(shap_soup, 'head')[1]
shap_html = _top_level_tags(shap_soup, 'body')

# The page template embeds shap_html[0] and shap_html[1]; fail here with a
# clear message instead of an IndexError in the middle of the template.
if len(shap_html) < 2:
    raise ValueError(
        "demo/shap.html: expected at least 2 top-level elements in <body>, "
        f"found {len(shap_html)}"
    )

# Single-page dashboard template, built as an f-string.
# Interpolated values: shap_head, dashboard_style, prediction_table_html,
# sankey_html, shap_html[0], shap_html[1] and sunburst_html.
# Literal braces in the inline jQuery snippet are doubled ({{ }}) because a
# single brace would start an f-string placeholder.
# NOTE(review): the page loads plotly.js from the 'plotly-latest' CDN URL;
# confirm that version is compatible with the plotly version that wrote
# demo/flow_chart.html.

message = f"""
<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
    <meta name="description" content="DashboardDemo">
    <meta name="author" content="shuaicao">

    <title>Dashboard</title>

    <!-- Bootstrap, jQuery, datatables -->
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.22/css/jquery.dataTables.css">
    <link rel="stylesheet" href="https://cdn.datatables.net/responsive/2.2.1/css/responsive.dataTables.min.css">
    <script type="text/javascript" charset="utf8" src="https://cdn.datatables.net/1.10.22/js/jquery.dataTables.js"></script>
    <script> $(document).ready( function () {{$('#prediction_table').DataTable();}} );</script>
    {shap_head}
    <style> {dashboard_style} </style>
  </head>


  <body>
    <script src="https://d3js.org/d3.v3.min.js"></script>
    <div class="position-relative overflow-hidden p-3 p-md-4 m-md-3 text-center bg-light">
        <h1 class="display-4 font-weight-normal">Dashboard</h1>
        {prediction_table_html}
    </div>
         
      
    <div class="d-md-flex flex-md-equal w-100 my-md-3 pl-md-3">
      <div class="w-50 bg-light mr-md-3 pt-3 px-3 pt-md-5 px-md-5 text-center overflow-hidden">
        <div class="my-3 py-3 overflow-hidden">
          <h2 class="display-5">Physician Level Patient Flow</h2>
          {sankey_html}
        </div>
      </div>
      <div class="w-50 bg-light mr-md-3 pt-3 px-3 pt-md-5 px-md-5 text-center overflow-hidden">
        <div class="my-3 p-3 overflow-hidden">
          <h2 class="display-5">Patient Outcome Indicator</h2>
          {shap_html[0]}
          {shap_html[1]}
        </div>
      </div>
    </div>
      
    <div class="bg-light mr-md-3 pt-3 px-3 pt-md-5 px-md-5 text-center overflow-hidden">
        <div id="sunburst" class="bg-light box-shadow mx-auto overflow-hidden" style="width: 90%; height: 1000px; border-radius: 21px 21px 0 0;">
        {sunburst_html}
        </div>
    </div>


    <!-- Bootstrap core JavaScript
    ================================================== -->
    <!-- Placed at the end of the document so the pages load faster -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
  </body>
</html>
"""
# Minify the assembled page (HTML comments and whitespace are stripped).
mined_html = htmlmin.minify(message, remove_comments=True, remove_empty_space=True, remove_all_empty_space=True)

# Write explicitly as UTF-8: the page declares <meta charset="utf-8">, so the
# bytes on disk must not depend on the platform's default encoding.
with open('demo/index.html', 'w', encoding='utf-8') as f:
    f.write(mined_html)