# Using Amazon SageMaker Clarify to Explain Decision Support in Hospital Triage

## Acute Care Clinical Context:

Decision support at admission time can be especially valuable for prioritizing resources in an acute care clinical setting such as a hospital. These critical resources include doctors and nurses, as well as specialized beds such as those in intensive care units. They limit the overall capacity of the hospital to treat patients.

Hospitals can use these resources more effectively by predicting the following:

* diagnoses at discharge
* procedures performed
* in-hospital mortality
* length of stay

Novel approaches in NLP, such as Bidirectional Encoder Representations from Transformers (BERT) models, have allowed for inference on clinical data, and specifically on clinical notes, at an accuracy level that was not attainable a few years ago. These advances make predicting key clinical indicators from notes data, and applying them in the real world, much more achievable.

The following references articulate how these indicators have been developed and are being used:

1) "Clinical Outcome Prediction from Admission Notes using Self-Supervised Knowledge Integration" 
    - https://aclanthology.org/2021.eacl-main.75.pdf

2) "Prediction of emergency department patient disposition based on natural language processing of triage notes"
    - https://pubmed.ncbi.nlm.nih.gov/31445253/    

3) "Application of Machine Learning in Intensive Care Unit (ICU) Settings Using MIMIC Dataset: Systematic Review"
    - https://www.amjmed.com/article/S0002-9343(20)30688-4/abstract

## Overview of the Notebook:

The intent of this notebook is to provide a practical guide for data scientists and machine learning engineers to collaborate with clinicians and to support real implementations of clinical indicator predictions. For that purpose, the algorithms must be explainable.

Advances in NLP algorithms, as in the studies above, have made predicting clinical indicators more accurate. To use machine learning models effectively in a production setting, however, clinicians also need more insight into how these models work. They need to know that the algorithms make clinical sense before going to production. Clinicians and data scientists need a way to evaluate the reliability and explainability of models over time, as more data is evaluated and the models are retrained.

This notebook takes one of these clinical triage indicators, in-hospital mortality, and shows how AWS services and infrastructure, along with pre-trained Hugging Face BERT models, can be used to train a binary classifier on text data, estimate a threshold value for triage, and then use Amazon SageMaker Clarify to explain which admission note text supports the recommendations the algorithm makes.

In this notebook we use the Hugging Face model `bigbird-base-mimic-mortality` (https://huggingface.co/mnaylor/bigbird-base-mimic-mortality). According to the publisher, this is a version of Google's base BigBird model fine-tuned on MIMIC admission notes. The model seeks to predict whether a patient will expire within a given ICU stay, based on the text available upon admission. We use this pre-trained model to demonstrate how NLP can provide a performant binary classifier for use in a clinical setting.

## Setup
We recommend you use `Python 3 (Data Science)` kernel on SageMaker Studio or `conda_python3` kernel on SageMaker Notebook Instance.

### Install dependencies

First, upgrade the environment to specific versions of the `sagemaker` and `huggingface_hub` libraries.

In [2]:
%pip install "sagemaker==2.116.0" "huggingface_hub==0.10.1" --upgrade --quiet


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
transformers 4.6.1 requires huggingface-hub==0.0.8, but you have huggingface-hub 0.10.1 which is incompatible.
datasets 1.6.2 requires huggingface-hub<0.1.0, but you have huggingface-hub 0.10.1 which is incompatible.
Note: you may need to restart the kernel to use updated packages.


Then let's make sure that the specific sagemaker version is loaded correctly. 

In [16]:
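import sagemaker
import pandas as pd
import boto3
import pprint
import os
from packaging import version

# Compare parsed versions: comparing the raw strings would rank "2.99.0" above "2.116.0".
assert version.parse(sagemaker.__version__) >= version.parse("2.116.0")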

Upgrade the SageMaker Python SDK, `boto3`, and `botocore`. The `captum` library is used later to visualize the feature attributions.

In [17]:
%pip install sagemaker --upgrade
%pip install boto3 --upgrade
%pip install botocore --upgrade

Collecting sagemaker
  Downloading sagemaker-2.134.1.tar.gz (673 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: sagemaker
  Building wheel for sagemaker (setup.py) ... done
  Created wheel for sagemaker: filename=sagemaker-2.134.1-py2.py3-none-any.whl size=910984 sha256=06c8998082a99b9fbfb373a753cd8d99e821ab7f85be554b3deeb4613f7e760e
  Stored in directory: /root/.cache/pip/wheels/f1/b7/ad/996ee655fd473eac12f2316862071a592872d4fb5771193749
Successfully built sagemaker
Installing collected packages: sagemaker
  Attempting uninstall: sagemaker
    Found existing installation: sagemaker 2.116.0
    Uninstalling sagemaker-2.116.0:
      Successfully uninstalled sagemaker-2.116.0
Successfully installed sagemaker-2.134.1

### Download the model 

First, let's download the model from Hugging Face so that we can deploy it on SageMaker. As mentioned above, we use the `bigbird-base-mimic-mortality` model from the Hugging Face Hub.

In [4]:
repository = "mnaylor/bigbird-base-mimic-mortality"
model_id = repository.split("/")[-1]

To download directly from the Hugging Face model hub, we can use the `git clone` command.

In [5]:
!git lfs install
!git clone https://huggingface.co/$repository

git: 'lfs' is not a git command. See 'git --help'.

The most similar command is
	log
Cloning into 'bigbird-base-mimic-mortality'...
remote: Enumerating objects: 19, done.
remote: Total 19 (delta 0), reused 0 (delta 0), pack-reused 19
Unpacking objects: 100% (19/19), done.


This will download the model to a directory called `bigbird-base-mimic-mortality`.

In [21]:
assert os.path.isdir(model_id)

### Create a custom inference script

To wrap this model for SageMaker inference, we create a custom `inference.py`. The Hugging Face Inference Toolkit allows the user to override the default methods of the `HuggingFaceHandlerService`.

The custom module can override the following methods:

* `model_fn(model_dir)` overrides the default method for loading a model. 
* `input_fn(input_data, content_type)` overrides the default method for pre-processing.
* `predict_fn(processed_data, model)` overrides the default method for predictions.
* `output_fn(prediction, accept)` overrides the default method for post-processing. (We are not overriding this function in this notebook.)
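
To make the division of labour between these hooks concrete, here is a minimal, offline sketch of how a request might flow through them. The `handle_request` driver and the `stub_model` scorer are hypothetical stand-ins introduced only for illustration; this is not the toolkit's actual implementation, and the simplified hook signatures are assumptions.

```python
import csv
import json
from io import StringIO
from typing import Callable, List


def input_fn(request_body: str, content_type: str) -> List[str]:
    """Deserialize a single-column CSV payload into a list of texts."""
    if content_type != "text/csv":
        raise ValueError(f"Unsupported content type: {content_type}")
    return [row[0] for row in csv.reader(StringIO(request_body)) if row]


def predict_fn(texts: List[str], model: Callable[[str], float]) -> List[float]:
    """Score every text with the loaded model."""
    return [model(text) for text in texts]


def output_fn(predictions: List[float], accept: str) -> str:
    """Serialize predictions in the format the caller asked for."""
    if accept == "application/json":
        return json.dumps({"predictions": predictions})
    return "\n".join(str(p) for p in predictions)


def handle_request(body: str, content_type: str, accept: str, model) -> str:
    """Hypothetical driver: chains the hooks in the order listed above."""
    return output_fn(predict_fn(input_fn(body, content_type), model), accept)


def stub_model(text: str) -> float:
    """Stub scorer (assumption): longer notes get a higher score, capped at 1.0."""
    return min(len(text) / 100.0, 1.0)


body = '"short note"\n"a much longer admission note"\n'
response = handle_request(body, "text/csv", "application/json", stub_model)
print(response)
```

In the real container, `model_fn` runs once at startup and its return value is what gets passed to `predict_fn` as the second argument; the stub above simply plays that role.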

The following code creates a folder called `code` and writes a custom `inference.py` script that overrides the methods above.

In [12]:
!mkdir -p $model_id/code

In [28]:
%%writefile $model_id/code/inference.py

import numpy as np
import pandas as pd
import torch
from io import StringIO
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BigBirdTokenizer, BigBirdForSequenceClassification
from typing import Any, Dict, List
import os
import traceback
import json

MODEL_NAME = "mnaylor/bigbird-base-mimic-mortality"

def model_fn(model_dir: str) -> tuple:
    """
    Load the tokenizer and the model and return them as a (tokenizer, model) tuple.
    Both are fetched from the Hugging Face Hub by name (MODEL_NAME);
    the artifacts under model_dir are not read here.
    """
    try :
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        return tokenizer, model
    except Exception as e :
        print("[Custom] Error occurred while loading the model.")
        traceback.print_exc()
        raise e


def predict_fn(input_data: List, tokenizer_model: tuple) -> np.ndarray:
    """
    Apply the model to the incoming request and return the positive-class probability
    """
    try :
        print("[Custom] input data is [{}], [{}]".format(type(input_data), input_data))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        tokenizer, huggingface_model = tokenizer_model
        # Pad and truncate so that a batch of notes with different lengths forms a single tensor.
        encoded_input = tokenizer(input_data, padding=True, truncation=True, return_tensors="pt").to(device)

        print("[Custom] inputs are [{}]".format(encoded_input))

        with torch.no_grad():
            output = huggingface_model(**encoded_input)
            prediction = torch.nn.Softmax(dim=1)(output.logits).detach().cpu().numpy()[:, 1]
            print("[Custom] output is [{}]".format(prediction))
            return prediction
    except Exception as e :
        print("[Custom] Error occurred while predicting.")
        traceback.print_exc()
        raise e
    
    
def input_fn(request_body: str, request_content_type: str) -> List[str]:
    """
    Deserialize and prepare the prediction input
    """
    try :
        print("[Custom] Request is [{}] with content type [{}]".format(request_body, request_content_type))

        if request_content_type == "text/csv":
            # We have a single column with the text.
            sentences = list(pd.read_csv(StringIO(request_body), header=None).values[:, 0].astype(str))
        else:
            raise ValueError("Invalid content type [{}]".format(request_content_type))
        return sentences
    except Exception as e :
        print("[Custom] Error occurred while reading the input.")
        traceback.print_exc()
        raise e

# def output_fn(predictions, accept):
#     print("[Custom] Prediction output type is [{}] [{}]".format(accept, predictions))
#     #res = predictions.astype(np.uint8)
#     res = json.dumps({"preds" : predictions.tolist()})
#     return res


Overwriting bigbird-base-mimic-mortality/code/inference.py


This script also requires some additional libraries. We list them in the `requirements.txt` file under the `code` directory.


In [29]:
%%writefile $model_id/code/requirements.txt

pandas
sentencepiece==0.1.97
transformers==4.18.0

Overwriting bigbird-base-mimic-mortality/code/requirements.txt


Let's verify that both files exist.

In [30]:
assert os.path.exists("./{}/code/inference.py".format(model_id)) == True
assert os.path.exists("./{}/code/requirements.txt".format(model_id)) == True

Then let's create the `hf_model.tar.gz` archive with all the model artifacts and the `inference.py` script.

In [33]:
import tarfile
file_name = "hf_model.tar.gz"
with tarfile.open(file_name, mode="w:gz") as archive:
    archive.add(model_id, recursive=True)
    
assert os.path.exists(file_name)
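
Before uploading, it can help to confirm what actually ended up in the archive. The sketch below repeats the archiving step on a throwaway directory and lists the members. The `build_archive` helper, the `demo-model` directory, and the placeholder files are all hypothetical; the real model directory holds the downloaded artifacts instead.

```python
import os
import tarfile
import tempfile


def build_archive(model_dir: str, archive_path: str) -> list:
    """Archive model_dir the same way as the cell above and return the member names."""
    with tarfile.open(archive_path, mode="w:gz") as archive:
        archive.add(model_dir, arcname=os.path.basename(model_dir), recursive=True)
    with tarfile.open(archive_path, mode="r:gz") as archive:
        return sorted(archive.getnames())


# Stand-in directory tree (assumption): placeholder files instead of real weights.
workdir = tempfile.mkdtemp()
model_dir = os.path.join(workdir, "demo-model")
os.makedirs(os.path.join(model_dir, "code"))
for relative_path in ("config.json", "code/inference.py", "code/requirements.txt"):
    with open(os.path.join(model_dir, relative_path), "w") as handle:
        handle.write("placeholder\n")

members = build_archive(model_dir, os.path.join(workdir, "hf_model.tar.gz"))
for name in members:
    print(name)
```

Listing the members makes it easy to spot a missing `code/inference.py` or an unexpected top-level directory before the archive is uploaded to S3.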

## Deploying the Huggingface model on SageMaker

### Set configurations

In [34]:
sess = sagemaker.Session()
sagemaker_client = boto3.client("sagemaker")
region = sess.boto_region_name
bucket = sess.default_bucket()
prefix = "sagemaker/DEMO-sagemaker-clarify-text"

# Define the IAM role
role = sagemaker.get_execution_role()

### Upload the hf_model.tar.gz to S3

In [35]:
model_path_s3 = sess.upload_data(path="hf_model.tar.gz", key_prefix=prefix)
model_path_s3

's3://sagemaker-us-east-1-721929407510/sagemaker/DEMO-sagemaker-clarify-text/hf_model.tar.gz'

Next, create a `HuggingFaceModel` object. It is used below to build the container definition for the model.

In [37]:
from sagemaker.huggingface import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data = model_path_s3,
    transformers_version='4.6.1',
    pytorch_version='1.7.1',
    py_version='py36',
    role=role,
    source_dir = "./{}/code".format(model_id),
    entry_point = "inference.py"
)

Define the instance type on which to deploy this model.

In [39]:
instance_type = "ml.g4dn.xlarge"
container_def = huggingface_model.prepare_container_def(instance_type=instance_type)
container_def

{'Image': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04',
 'Environment': {'SAGEMAKER_PROGRAM': 'inference.py',
  'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code',
  'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
  'SAGEMAKER_REGION': 'us-east-1'},
 'ModelDataUrl': 's3://sagemaker-us-east-1-721929407510/huggingface-pytorch-inference-2023-02-23-06-03-45-911/model.tar.gz'}

## Create model

The following parameters are required to create a SageMaker model:

* `ExecutionRoleArn`: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts/ docker images for deployment

* `ModelName`: name of the SageMaker model.

* `PrimaryContainer`: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions.




In [44]:
model_name = "hospital-triage-model"

sagemaker_client.create_model(
    ExecutionRoleArn=role,
    ModelName=model_name,
    PrimaryContainer=container_def,
)
print(f"Model created: {model_name}")

Model created: hospital-triage-model


## Create endpoint config
Create an endpoint configuration by calling the `create_endpoint_config` API, supplying the same `model_name` used in the `create_model` call. `create_endpoint_config` accepts an additional `ClarifyExplainerConfig` parameter that enables the Clarify explainer. The SHAP baseline is mandatory; it can be provided either as inline baseline data (the `ShapBaseline` parameter) or as an S3 baseline file (the `ShapBaselineUri` parameter). See the developer guide for the optional parameters.

Here we use a special token as the baseline.

In [42]:
baseline = [["<UNK>"]]
print(f"SHAP baseline: {baseline}")

SHAP baseline: [['<UNK>']]


The `TextConfig` is configured with sentence-level granularity and English as the language. With sentence granularity, each sentence is a feature, so an admission note needs a few sentences for a useful visualization.

In [45]:
endpoint_config_name = "hospital-triage-model-ep-config"
csv_serializer = sagemaker.serializers.CSVSerializer()
json_deserializer = sagemaker.deserializers.JSONDeserializer()

sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "MainVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        }
    ],
    ExplainerConfig={
        "ClarifyExplainerConfig": {
            "InferenceConfig": {"FeatureTypes": ["text"]},
            "ShapConfig": {
                "ShapBaselineConfig": {"ShapBaseline": csv_serializer.serialize(baseline)},
                "TextConfig": {"Granularity": "sentence", "Language": "en"},
            },
        }
    },
)

{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint-config/hospital-triage-model-ep-config',
 'ResponseMetadata': {'RequestId': 'ed9a7555-c241-4105-bf0d-2ce2ff0c8fe8',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'ed9a7555-c241-4105-bf0d-2ce2ff0c8fe8',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '112',
   'date': 'Thu, 23 Feb 2023 06:06:03 GMT'},
  'RetryAttempts': 0}}

## Create endpoint
Once the model and endpoint configuration are ready, use the `create_endpoint` API to create the endpoint. The `endpoint_name` must be unique within an AWS Region in your AWS account. The `create_endpoint` call returns immediately, with the endpoint in the `Creating` state; provisioning continues in the background.

In [47]:
endpoint_name = "hospital-triage-prediction-endpoint"
sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

{'EndpointArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint/hospital-triage-prediction-endpoint',
 'ResponseMetadata': {'RequestId': '9b20dd03-bb87-4b1b-b4fb-30b9bce57f45',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '9b20dd03-bb87-4b1b-b4fb-30b9bce57f45',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '103',
   'date': 'Thu, 23 Feb 2023 06:06:22 GMT'},
  'RetryAttempts': 0}}

Wait for the endpoint to be in `InService` state


In [48]:
sess.wait_for_endpoint(endpoint_name)

-------!

{'EndpointName': 'hospital-triage-prediction-endpoint',
 'EndpointArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint/hospital-triage-prediction-endpoint',
 'EndpointConfigName': 'hospital-triage-model-ep-config',
 'ProductionVariants': [{'VariantName': 'MainVariant',
   'DeployedImages': [{'SpecifiedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04',
     'ResolvedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference@sha256:1d2383a5e52c26db3d2262742d211b979b170fa30a1855302a022e0b9018e6c6',
     'ResolutionTime': datetime.datetime(2023, 2, 23, 6, 6, 23, 451000, tzinfo=tzlocal())}],
   'CurrentWeight': 1.0,
   'DesiredWeight': 1.0,
   'CurrentInstanceCount': 1,
   'DesiredInstanceCount': 1}],
 'EndpointStatus': 'InService',
 'CreationTime': datetime.datetime(2023, 2, 23, 6, 6, 22, 751000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2023, 2, 23, 6, 9, 46, 
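
The wait step above can also be written as an explicit polling loop, which is useful when you want your own timeout or failure handling. The sketch below is a minimal version under stated assumptions: `wait_until_in_service` is a hypothetical helper, and `FakeSageMakerClient` is a stand-in so the example runs offline. With a real `boto3.client("sagemaker")`, the same `describe_endpoint` call returns the `EndpointStatus` field.

```python
import time


def wait_until_in_service(client, endpoint_name, poll_seconds=30, max_polls=40):
    """Poll describe_endpoint until the endpoint is InService, fails, or we give up."""
    for _ in range(max_polls):
        status = client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        time.sleep(poll_seconds)
    raise TimeoutError(f"Endpoint {endpoint_name} not in service after {max_polls} polls")


class FakeSageMakerClient:
    """Offline stand-in (assumption) that replays a fixed sequence of statuses."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def describe_endpoint(self, EndpointName):
        return {"EndpointName": EndpointName, "EndpointStatus": next(self._statuses)}


fake_client = FakeSageMakerClient(["Creating", "Creating", "InService"])
final_status = wait_until_in_service(
    fake_client, "hospital-triage-prediction-endpoint", poll_seconds=0
)
print(final_status)
```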

## Test endpoint without any explanations

Use the `EnableExplanations` parameter to disable the explanations for this request.



In [80]:
sagemaker_runtime_client = boto3.client("sagemaker-runtime")

sample_admission_note = pd.DataFrame(["""Patient is a 25-year-old male with a chief complaint of acute chest pain. 
    Patient reports the pain began suddenly while at work and has been constant since. 
    Patient rates the pain as 8/10 in severity. Patient denies any radiation of pain, shortness of breath, nausea, or vomiting. 
    Patient reports no previous history of chest pain. 
    Vital signs are as follows: blood pressure 140/90 mmH. Heart rate 92 beats per minute. 
    Respiratory rate 18 breaths per minute. Oxygen saturation 96% on room air. 
    Physical examination reveals mild tenderness to palpation over the precordium and clear lung fields. 
    EKG shows sinus tachycardia with no ST-elevations or depressions. """])

response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Accept="text/csv",
    Body=csv_serializer.serialize(sample_admission_note.iloc[:1, :].to_numpy()),
    EnableExplanations="`false`",  # Do not provide explanations
)

pprint.pprint(response)

{'Body': <botocore.response.StreamingBody object at 0x7f4e28b2e250>,
 'ContentType': 'application/json',
 'InvokedProductionVariant': 'MainVariant',
 'ResponseMetadata': {'HTTPHeaders': {'content-length': '100',
                                      'content-type': 'application/json',
                                      'date': 'Thu, 23 Feb 2023 06:40:22 GMT',
                                      'x-amzn-invoked-production-variant': 'MainVariant',
                                      'x-amzn-requestid': '1df6d126-e2c3-4f86-9495-1a494373c5cf'},
                      'HTTPStatusCode': 200,
                      'RequestId': '1df6d126-e2c3-4f86-9495-1a494373c5cf',
                      'RetryAttempts': 0}}


In [81]:
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)

{'explanations': {},
 'predictions': {'content_type': 'text/csv', 'data': '0.014596084\n'},
 'version': '1.0'}
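
The response above carries the scores as a CSV string, so they need to be parsed before a triage threshold can be applied. The sketch below shows one way to do that. The helper names are hypothetical, and the `0.5` cut-off is an assumption borrowed from the visualization code later in this notebook, not a clinically validated threshold.

```python
import csv


def parse_scores(result: dict) -> list:
    """Extract float scores from the CSV payload in result["predictions"]["data"]."""
    rows = csv.reader(result["predictions"]["data"].splitlines())
    return [float(row[0]) for row in rows if row]


def triage_label(score: float, threshold: float = 0.5) -> str:
    """Map a mortality probability to a coarse label (threshold is an assumption)."""
    return "acute" if score > threshold else "non-acute"


# Same shape as the deserialized response shown above.
result = {
    "explanations": {},
    "predictions": {"content_type": "text/csv", "data": "0.014596084\n"},
    "version": "1.0",
}

for score in parse_scores(result):
    print(f"{score:.3f} -> {triage_label(score)}")
```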


As we can see, the model predicts this as a non-acute case, since the probability is as low as `0.015`. But which statements in the admission note did the model use to reach that conclusion? To answer that, we can leverage SageMaker Clarify.

## Explain the predictions using Amazon SageMaker Clarify

Growing business needs, clinical needs, and legislative regulations require explanations of why a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.

How does the Kernel SHAP algorithm work? Kernel SHAP is a local explanation method: it explains one instance (row) of the dataset at a time. To explain an instance, it perturbs the feature values, that is, it replaces the values of some features with a baseline (non-informative) value, and then gets predictions from the model for the perturbed samples. It repeats this a number of times per instance (determined by the optional parameter `num_samples` in `SHAPConfig`) and computes the importance of each feature from how the model prediction changed.
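
The perturb-and-compare idea can be illustrated on a toy problem. The sketch below computes exact Shapley values by enumerating every subset of sentences, which is feasible only for a handful of features; Kernel SHAP instead approximates these values from a sample of perturbations. The keyword-based `toy_score` function and its weights are assumptions made up for this example and have nothing to do with the deployed model.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List

BASELINE = "<UNK>"


def perturb(units: List[str], keep: set) -> List[str]:
    """Replace every unit whose index is not in `keep` with the baseline value."""
    return [unit if i in keep else BASELINE for i, unit in enumerate(units)]


def shapley_values(units: List[str], score: Callable[[List[str]], float]) -> List[float]:
    """Exact Shapley values: weighted average marginal contribution of each unit."""
    n = len(units)
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                without_i = score(perturb(units, set(subset)))
                with_i = score(perturb(units, set(subset) | {i}))
                total += weight * (with_i - without_i)
        values.append(total)
    return values


# Toy scorer (assumption): a base rate plus fixed weights for a few keywords.
RISK_WORDS = {"sepsis": 0.4, "shock": 0.3, "stable": -0.1}


def toy_score(units: List[str]) -> float:
    text = " ".join(units).lower()
    return 0.2 + sum(weight for word, weight in RISK_WORDS.items() if word in text)


sentences = ["Severe sepsis.", "Septic shock noted.", "Vitals stable."]
attributions = shapley_values(sentences, toy_score)
for sentence, value in zip(sentences, attributions):
    print(f"{value:+.2f}  {sentence}")
```

The attributions sum to the difference between the score of the full text and the score of the all-baseline text, which is the property that lets SHAP values be read as a decomposition of a single prediction.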

This functionality extends to text data. To explain text, we need a `TextConfig`, an optional parameter of `SHAPConfig` that you must provide if you want explanations for the text features in your dataset. `TextConfig` takes the following parameters:

* `granularity` (required): To explain text features, Clarify further breaks down text into smaller text units, and considers each such text unit as a feature. The parameter granularity informs the level to which Clarify will break down the text: token, sentence, or paragraph are the allowed values for granularity.
* `language` (required): the language of the text features. This is required to tokenize the text to break them down to their granular form.
* `max_top_tokens` (optional): the number of top token attributions that will be shown in the output (we need this because the size of vocabulary can be very big). This is an optional parameter, and defaults to 50.
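
As a small illustration of the required parameters, the sketch below assembles the SHAP portion of the explainer configuration and rejects a granularity outside the allowed values listed above. `build_shap_config` is a hypothetical helper; the dictionary keys mirror the ones used in the endpoint configuration cell earlier in this notebook.

```python
ALLOWED_GRANULARITIES = {"token", "sentence", "paragraph"}


def build_shap_config(baseline_csv: str, granularity: str = "sentence", language: str = "en") -> dict:
    """Assemble a ShapConfig dictionary for a single text feature (hypothetical helper)."""
    if granularity not in ALLOWED_GRANULARITIES:
        allowed = ", ".join(sorted(ALLOWED_GRANULARITIES))
        raise ValueError(f"granularity must be one of: {allowed}")
    return {
        "ShapBaselineConfig": {"ShapBaseline": baseline_csv},
        "TextConfig": {"Granularity": granularity, "Language": language},
    }


shap_config = build_shap_config("<UNK>")
print(shap_config)
```

Validating the value locally gives a clearer error than waiting for the API call to reject the configuration.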

Kernel SHAP requires a baseline (also known as a background dataset). For tabular features, the baseline value for a feature is ideally a non-informative or least informative value for that feature. For text features, however, the baseline must be the value that replaces an individual text unit (token, sentence, or paragraph). In this notebook the baseline for the admission note text is `<UNK>` and the granularity is sentence, so every time a sentence has to be replaced in a perturbed input, it is replaced with `<UNK>`.

If baseline is not provided, a baseline is calculated automatically by SageMaker Clarify using K-means or K-prototypes in the input dataset for tabular features. For text features, if baseline is not provided, the default replacement value will be the string <PAD>.

## Test endpoint with explanations

This time we'll invoke the endpoint with explanations enabled (which is the default setting).

In [82]:

response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Accept="text/csv",
    Body=csv_serializer.serialize(sample_admission_note.iloc[:1, :].to_numpy())
)

result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)

{'explanations': {'kernel_shap': [[{'attributions': [{'attribution': [-0.13967000051377493],
                                                      'description': {'partial_text': 'Patient '
                                                                                      'is '
                                                                                      'a '
                                                                                      '25-year-old '
                                                                                      'male '
                                                                                      'with '
                                                                                      'a '
                                                                                      'chief '
                                                                                      'complaint '
                                                     

You can see that the `kernel_shap` values are returned with the response. To interpret them at the sentence level, let's use some visualizations.
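
For a quick look without any plotting library, the attributions can also be ranked as plain text. The sketch below walks the same response structure as the output above. `rank_attributions` is a hypothetical helper, and the numbers in `example_result` are hand-made illustrative values, not output from the deployed model.

```python
def rank_attributions(result: dict, record_index: int = 0) -> list:
    """Return (attribution, text) pairs for one record, largest magnitude first."""
    record = result["explanations"]["kernel_shap"][record_index]
    pairs = [
        (item["attribution"][0], item["description"]["partial_text"])
        for item in record[0]["attributions"]
    ]
    return sorted(pairs, key=lambda pair: abs(pair[0]), reverse=True)


# Hand-made response fragment (assumption): values are illustrative only.
example_result = {
    "explanations": {
        "kernel_shap": [[{"attributions": [
            {"attribution": [-0.14], "description": {"partial_text": "Patient is a 25-year-old male."}},
            {"attribution": [0.02], "description": {"partial_text": "Patient rates the pain as 8/10."}},
            {"attribution": [-0.05], "description": {"partial_text": "Oxygen saturation 96% on room air."}},
        ]}]]
    }
}

ranked = rank_attributions(example_result)
for attribution, text in ranked:
    print(f"{attribution:+.2f}  {text}")
```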




We'll define some utility functions to get the visualizations of the SHAP values.

In [64]:
import csv
import numpy as np
from captum.attr import visualization

def visualization_record(
    attributions,  # list of attributions for the tokens
    text,  # list of tokens
    pred,  # the prediction value obtained from the endpoint
    delta,
    true_label,  # the true label from the dataset
    normalize=True,  # normalizes the attributions so that the max absolute value is 1. Yields stronger colors.
    max_frac_to_show=0.05,  # what fraction of tokens to highlight, set to 1 for all.
    match_to_pred=False,  # whether to limit highlights to red for negative predictions and green for positive ones.
    # By enabling `match_to_pred` you show what tokens contribute to a high/low prediction not those that oppose it.
):
    
    if normalize:
        attributions = attributions / max(max(attributions), max(-attributions))
    if max_frac_to_show is not None and max_frac_to_show < 1:
        num_show = int(max_frac_to_show * attributions.shape[0])
        sal = attributions
        if pred < 0.5:
            sal = -sal
        if not match_to_pred:
            sal = np.abs(sal)
        top_idxs = np.argsort(-sal)[:num_show]
        mask = np.zeros_like(attributions)
        mask[top_idxs] = 1
        attributions = attributions * mask
    return visualization.VisualizationDataRecord(
        attributions,
        pred,
        int(pred > 0.5),
        true_label,
        attributions.sum() > 0,
        attributions.sum(),
        text,
        delta,
    )

def visualize_result(result, all_labels):
    if not result["explanations"]:
        print(f"No Clarify explanations for the record(s)")
        return
    all_explanations = result["explanations"]["kernel_shap"]
    all_predictions = list(csv.reader(result["predictions"]["data"].splitlines()))

    labels = []
    predictions = []
    explanations = []

    for i, expl in enumerate(all_explanations):
        if expl:
            labels.append(all_labels[i])
            predictions.append(all_predictions[i])
            explanations.append(all_explanations[i])

    attributions_dataset = [
        np.array([attr["attribution"][0] for attr in expl[0]["attributions"]])
        for expl in explanations
    ]
    tokens_dataset = [
        np.array([attr["description"]["partial_text"] for attr in expl[0]["attributions"]])
        for expl in explanations
    ]

    # You can customize the following display settings
    normalize = True
    max_frac_to_show = 1
    match_to_pred = False
    vis = []
    for attr, token, pred, label in zip(attributions_dataset, tokens_dataset, predictions, labels):
        vis.append(
            visualization_record(
                -attr, token, float(pred[0]), 0.0, label, normalize, max_frac_to_show, match_to_pred
            )
        )
    _ = visualization.visualize_text(vis)

def invoke_visualize(test_admission_notes, true_label):
    response = sagemaker_runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Accept="text/csv",
        Body=csv_serializer.serialize(test_admission_notes.iloc[:1, :].to_numpy())
    )
    result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
    visualize_result(result, [true_label])

### Explain the predictions for a non-acute admission note

In [83]:
invoke_visualize(sample_admission_note, 0)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.01),True,2.39,"Patient is a 25-year-old male with a chief complaint of acute chest pain. Patient reports the pain began suddenly while at work and has been constant since. Patient rates the pain as 8/10 in severity. Patient denies any radiation of pain, shortness of breath, nausea, or vomiting. Patient reports no previous history of chest pain. Vital signs are as follows: blood pressure 140/90 mmH. Heart rate 92 beats per minute. Respiratory rate 18 breaths per minute. Oxygen saturation 96% on room air. Physical examination reveals mild tenderness to palpation over the precordium and clear lung fields. EKG shows sinus tachycardia with no ST-elevations or depressions."


### Explain the predictions for an acute admission note

In [66]:
test_admission_notes_accute = pd.DataFrame(
    ["""Patient is a 72-year-old female with a chief complaint of severe sepsis and septic shock. 
    Patient reports a fever, chills, and weakness for the past 3 days, as well as decreased urine output and confusion. 
    Patient has a history of chronic obstructive pulmonary disease (COPD) and a recent hospitalization for pneumonia. 
    Vital signs are as follows: blood pressure 80/40 mmHg. Heart rate 130 beats per minute. Respiratory rate 30 breaths per minute. 
    Oxygen saturation 82% on 4L of oxygen via nasal cannula. 
    Physical examination reveals diffuse erythema and warmth over the lower extremities and positive findings for sepsis such as altered mental status, 
    tachycardia, and tachypnea. Blood cultures were taken and antibiotic therapy was started with appropriate coverage."""]
    )

invoke_visualize(test_admission_notes_accute, [1])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
[1],1 (0.75),False,-1.26,"Patient is a 72-year-old female with a chief complaint of severe sepsis and septic shock. Patient reports a fever, chills, and weakness for the past 3 days, as well as decreased urine output and confusion. Patient has a history of chronic obstructive pulmonary disease (COPD) and a recent hospitalization for pneumonia. Vital signs are as follows: blood pressure 80/40 mmHg. Heart rate 130 beats per minute. Respiratory rate 30 breaths per minute. Oxygen saturation 82% on 4L of oxygen via nasal cannula. Physical examination reveals diffuse erythema and warmth over the lower extremities and positive findings for sepsis such as altered mental status, tachycardia, and tachypnea. Blood cultures were taken and antibiotic therapy was started with appropriate coverage."


### Clean Up

Clean up the deployed resources to avoid incurring further charges.

In [None]:
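# Delete the endpoint, its configuration, and the model that were created above through the boto3 client.
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sagemaker_client.delete_model(ModelName=model_name)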