{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Amazon SageMaker Clarify to explain Decision Support in Hospital Triage" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Acute Care Clinical Context:\n", "\n", "Decision support at admission time can be especially valuable for prioritizing resources in an acute care clinical setting such as a hospital. These critical resources come in the form of doctors and nurses, as well as specialized beds, such as those in intensive care units. They limit the overall capacity of the hospital to treat patients.\n", "\n", "Hospitals can use these resources more effectively by predicting the following:\n", "diagnoses at discharge, procedures performed, in-hospital mortality, and length of stay.\n", "\n", "Novel approaches in NLP, such as Bidirectional Encoder Representations from Transformers (BERT) models, have allowed for inference on clinical data, and specifically notes, at an accuracy level not attainable a number of years ago. 
These advances make predicting key clinical indicators from notes data, and applying them in the real world, much more achievable.\n", "\n", "The following references articulate how these indicators have been developed and are being used:\n", "\n", "1) \"Clinical Outcome Prediction from Admission Notes using Self-Supervised Knowledge Integration\" \n", " - https://aclanthology.org/2021.eacl-main.75.pdf\n", "\n", "2) \"Prediction of emergency department patient disposition based on natural language processing of triage notes\"\n", " - https://pubmed.ncbi.nlm.nih.gov/31445253/ \n", "\n", "3) \"Application of Machine Learning in Intensive Care Unit (ICU) Settings Using MIMIC Dataset: Systematic Review\"\n", " - https://www.amjmed.com/article/S0002-9343(20)30688-4/abstract\n", "\n", "## Overview of the Notebook:\n", "\n", "The intent of this notebook is to provide a practical guide for data scientists and machine learning engineers to collaborate with clinicians, and to support real implementations of clinical indicator predictions. As such, explainability of the algorithms is required.\n", "\n", "Advances in NLP algorithms, as in the studies above, have made predicting clinical indicators more accurate, yet in order to effectively use machine learning models in a production setting, clinicians also need more insight into how these models work. They need to know that these algorithms make clinical sense before going to production. 
Clinicians and data scientists need a way to evaluate the reliability and explainability of models over time, as more data is evaluated and machine learning models are retrained.\n", "\n", "This notebook will take one of these clinical triage indicators, in-hospital mortality, and show how AWS services and infrastructure, along with pre-trained HuggingFace BERT models, can be used to train a binary classifier on text data, estimate a threshold value for triage, and then use Amazon SageMaker Clarify to explain which admission note text supports the recommendations the algorithm is making.\n", "\n", "In this notebook we use the HuggingFace BERT Model - `bigbird-base-mimic-mortality` (https://huggingface.co/mnaylor/bigbird-base-mimic-mortality). According to the publisher, this is a fine-tuned version of Google's base BigBird model with MIMIC admission notes. This model seeks to predict whether a certain patient will expire within a given ICU stay, based on the text available upon admission. This is the pre-trained BERT model we will use in this notebook in order to demonstrate how NLP can be used to create a performant binary classifier for use in a clinical setting." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "We recommend you use the `Python 3 (Data Science)` kernel on SageMaker Studio or the `conda_python3` kernel on a SageMaker Notebook Instance.\n", "\n", "### Install dependencies" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "First, let us upgrade the environment to specific versions of the `sagemaker` and `huggingface_hub` libraries." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "transformers 4.6.1 requires huggingface-hub==0.0.8, but you have huggingface-hub 0.10.1 which is incompatible.\n", "datasets 1.6.2 requires huggingface-hub<0.1.0, but you have huggingface-hub 0.10.1 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.2.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install \"sagemaker==2.116.0\" \"huggingface_hub==0.10.1\" --upgrade --quiet\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then let's make sure that the expected sagemaker version is loaded correctly. " ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [] }, "outputs": [], "source": [ "import sagemaker\n", "import pandas as pd\n", "import boto3\n", "import pprint\n", "import os\n", "from packaging import version\n", "\n", "# Compare parsed versions; a plain string comparison would order \"2.99.0\" after \"2.116.0\".\n", "assert version.parse(sagemaker.__version__) >= version.parse(\"2.116.0\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Upgrade the SageMaker Python SDK, boto3, and botocore. Captum is used to visualize the feature attributions."
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: sagemaker in /opt/conda/lib/python3.7/site-packages (2.116.0)\n", "Collecting sagemaker\n", " Downloading sagemaker-2.134.1.tar.gz (673 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m673.4/673.4 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: attrs<23,>=20.3.0 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (21.4.0)\n", "Requirement already satisfied: boto3<2.0,>=1.26.28 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (1.26.76)\n", "Requirement already satisfied: google-pasta in /opt/conda/lib/python3.7/site-packages (from sagemaker) (0.2.0)\n", "Requirement already satisfied: numpy<2.0,>=1.9.0 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (1.21.6)\n", "Requirement already satisfied: protobuf<4.0,>=3.1 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (3.20.1)\n", "Requirement already satisfied: protobuf3-to-dict<1.0,>=0.1.5 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (0.1.5)\n", "Requirement already satisfied: smdebug_rulesconfig==1.0.1 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (1.0.1)\n", "Requirement already satisfied: importlib-metadata<5.0,>=1.4.0 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (4.12.0)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from sagemaker) (21.3)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from sagemaker) (1.3.5)\n", "Requirement already satisfied: pathos in /opt/conda/lib/python3.7/site-packages (from sagemaker) (0.2.9)\n", "Requirement already satisfied: schema in 
/opt/conda/lib/python3.7/site-packages (from sagemaker) (0.7.5)\n", "Requirement already satisfied: botocore<1.30.0,>=1.29.76 in /opt/conda/lib/python3.7/site-packages (from boto3<2.0,>=1.26.28->sagemaker) (1.29.76)\n", "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/conda/lib/python3.7/site-packages (from boto3<2.0,>=1.26.28->sagemaker) (0.10.0)\n", "Requirement already satisfied: s3transfer<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from boto3<2.0,>=1.26.28->sagemaker) (0.6.0)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata<5.0,>=1.4.0->sagemaker) (3.8.1)\n", "Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata<5.0,>=1.4.0->sagemaker) (4.3.0)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging>=20.0->sagemaker) (2.4.6)\n", "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from protobuf3-to-dict<1.0,>=0.1.5->sagemaker) (1.14.0)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->sagemaker) (2.8.1)\n", "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->sagemaker) (2019.3)\n", "Requirement already satisfied: dill>=0.3.5.1 in /opt/conda/lib/python3.7/site-packages (from pathos->sagemaker) (0.3.5.1)\n", "Requirement already satisfied: pox>=0.3.1 in /opt/conda/lib/python3.7/site-packages (from pathos->sagemaker) (0.3.1)\n", "Requirement already satisfied: ppft>=1.7.6.5 in /opt/conda/lib/python3.7/site-packages (from pathos->sagemaker) (1.7.6.5)\n", "Requirement already satisfied: multiprocess>=0.70.13 in /opt/conda/lib/python3.7/site-packages (from pathos->sagemaker) (0.70.13)\n", "Requirement already satisfied: contextlib2>=0.5.5 in /opt/conda/lib/python3.7/site-packages (from schema->sagemaker) 
(0.6.0.post1)\n", "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /opt/conda/lib/python3.7/site-packages (from botocore<1.30.0,>=1.29.76->boto3<2.0,>=1.26.28->sagemaker) (1.26.12)\n", "Building wheels for collected packages: sagemaker\n", " Building wheel for sagemaker (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for sagemaker: filename=sagemaker-2.134.1-py2.py3-none-any.whl size=910984 sha256=06c8998082a99b9fbfb373a753cd8d99e821ab7f85be554b3deeb4613f7e760e\n", " Stored in directory: /root/.cache/pip/wheels/f1/b7/ad/996ee655fd473eac12f2316862071a592872d4fb5771193749\n", "Successfully built sagemaker\n", "Installing collected packages: sagemaker\n", " Attempting uninstall: sagemaker\n", " Found existing installation: sagemaker 2.116.0\n", " Uninstalling sagemaker-2.116.0:\n", " Successfully uninstalled sagemaker-2.116.0\n", "Successfully installed sagemaker-2.134.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.2.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: boto3 in /opt/conda/lib/python3.7/site-packages (1.26.76)\n", "Collecting boto3\n", " Downloading boto3-1.26.77-py3-none-any.whl (132 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.7/132.7 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: s3transfer<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from boto3) (0.6.0)\n", "Collecting botocore<1.30.0,>=1.29.77\n", " Downloading botocore-1.29.77-py3-none-any.whl (10.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/conda/lib/python3.7/site-packages (from boto3) (0.10.0)\n", "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /opt/conda/lib/python3.7/site-packages (from botocore<1.30.0,>=1.29.77->boto3) (2.8.1)\n", "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /opt/conda/lib/python3.7/site-packages (from botocore<1.30.0,>=1.29.77->boto3) (1.26.12)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.30.0,>=1.29.77->boto3) (1.14.0)\n", "Installing 
collected packages: botocore, boto3\n", " Attempting uninstall: botocore\n", " Found existing installation: botocore 1.29.76\n", " Uninstalling botocore-1.29.76:\n", " Successfully uninstalled botocore-1.29.76\n", " Attempting uninstall: boto3\n", " Found existing installation: boto3 1.26.76\n", " Uninstalling boto3-1.26.76:\n", " Successfully uninstalled boto3-1.26.76\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "awscli 1.25.63 requires botocore==1.27.62, but you have botocore 1.29.77 which is incompatible.\n", "awscli 1.25.63 requires PyYAML<5.5,>=3.10, but you have pyyaml 6.0 which is incompatible.\n", "awscli 1.25.63 requires rsa<4.8,>=3.1.2, but you have rsa 4.9 which is incompatible.\n", "aiobotocore 1.2.2 requires botocore<1.19.53,>=1.19.52, but you have botocore 1.29.77 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed boto3-1.26.77 botocore-1.29.77\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.2.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: botocore in /opt/conda/lib/python3.7/site-packages (1.29.77)\n", "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/conda/lib/python3.7/site-packages (from botocore) (0.10.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /opt/conda/lib/python3.7/site-packages (from botocore) (1.26.12)\n", "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /opt/conda/lib/python3.7/site-packages (from botocore) (2.8.1)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore) (1.14.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.2.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install sagemaker --upgrade\n", "%pip install boto3 --upgrade\n", "%pip install botocore --upgrade" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Download the model " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's download the model from Hugging Face so that we can deploy it in SageMaker. As mentioned above, we are going to download the `bigbird-base-mimic-mortality` model from the Hugging Face Hub. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "repository = \"mnaylor/bigbird-base-mimic-mortality\"\n", "model_id = repository.split(\"/\")[-1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To download directly from the Hugging Face model hub, we can use the `git clone` command." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "git: 'lfs' is not a git command. 
See 'git --help'.\n", "\n", "The most similar command is\n", "\tlog\n", "Cloning into 'bigbird-base-mimic-mortality'...\n", "remote: Enumerating objects: 19, done.\u001b[K\n", "remote: Total 19 (delta 0), reused 0 (delta 0), pack-reused 19\u001b[K\n", "Unpacking objects: 100% (19/19), done.\n" ] } ], "source": [ "!git lfs install\n", "!git clone https://huggingface.co/$repository" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "This will download the model to a directory called `bigbird-base-mimic-mortality`." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "tags": [] }, "outputs": [], "source": [ "assert os.path.exists(\"./{}\".format(model_id))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a custom inference script\n", "\n", "In order to wrap this model and use it for SageMaker inference, we are going to create a custom `inference.py`. The Hugging Face Inference Toolkit allows the user to override the default methods of the `HuggingFaceHandlerService`.\n", "\n", "The custom module can override the following methods:\n", "\n", "* `model_fn(model_dir)` overrides the default method for loading a model. \n", "* `input_fn(input_data, content_type)` overrides the default method for pre-processing.\n", "* `predict_fn(processed_data, model)` overrides the default method for predictions.\n", "* `output_fn(prediction, accept)` overrides the default method for post-processing. (We are not overriding this function in this notebook.)\n", "\n", "The following code creates a folder called `code` and writes a custom `inference.py` script that overrides the methods above."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [], "source": [ "!mkdir -p $model_id/code" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting bigbird-base-mimic-mortality/code/inference.py\n" ] } ], "source": [ "%%writefile $model_id/code/inference.py\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from io import StringIO\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", "from transformers import BigBirdTokenizer, BigBirdForSequenceClassification\n", "from typing import Any, Dict, List\n", "import os\n", "import traceback\n", "import json\n", "\n", "MODEL_NAME = \"mnaylor/bigbird-base-mimic-mortality\"\n", "\n", "def model_fn(model_dir: str) -> tuple:\n", "    \"\"\"\n", "    Load the tokenizer and model (from the Hugging Face Hub by name; model_dir is not used)\n", "    \"\"\"\n", "    try:\n", "        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)\n", "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "        model.to(device)\n", "        return tokenizer, model\n", "    except Exception as e:\n", "        print(\"[Custom] Error occurred while loading the model.\")\n", "        traceback.print_exc()\n", "        raise e\n", "\n", "\n", "def predict_fn(input_data: List, tokenizer_model: tuple) -> np.ndarray:\n", "    \"\"\"\n", "    Apply model to the incoming request\n", "    \"\"\"\n", "    try:\n", "        print(\"[Custom] input data is [{}], [{}]\".format(type(input_data), input_data))\n", "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "        tokenizer, huggingface_model = tokenizer_model\n", "        # Pad and truncate so that a batch of several records can be stacked into one tensor.\n", "        encoded_input = tokenizer(input_data, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n", "\n", "        print(\"[Custom] inputs are [{}]\".format(encoded_input))\n", "\n", "        with torch.no_grad():\n", "            output = huggingface_model(**encoded_input)\n", "            # Probability of the positive class (index 1) for each record.\n", "            prediction = torch.nn.Softmax(dim=1)(output.logits).detach().cpu().numpy()[:, 1]\n", "        print(\"[Custom] output is [{}]\".format(prediction))\n", "        return prediction\n", "    except Exception as e:\n", "        print(\"[Custom] Error occurred while predicting.\")\n", "        traceback.print_exc()\n", "        raise e\n", "\n", "\n", "def input_fn(request_body: str, request_content_type: str) -> List[str]:\n", "    \"\"\"\n", "    Deserialize and prepare the prediction input\n", "    \"\"\"\n", "    try:\n", "        print(\"[Custom] Request is [{}] with content type [{}]\".format(request_body, request_content_type))\n", "\n", "        if request_content_type == \"text/csv\":\n", "            # We have a single column with the text.\n", "            sentences = list(pd.read_csv(StringIO(request_body), header=None).values[:, 0].astype(str))\n", "        else:\n", "            raise ValueError(\"Invalid content type [{}]\".format(request_content_type))\n", "        return sentences\n", "    except Exception as e:\n", "        print(\"[Custom] Error occurred while reading the input.\")\n", "        traceback.print_exc()\n", "        raise e\n", "\n", "# def output_fn(predictions, accept):\n", "#     print(\"[Custom] Prediction output type is [{}] [{}]\".format(accept, predictions))\n", "#     #res = predictions.astype(np.uint8)\n", "#     res = json.dumps({\"preds\" : predictions.tolist()})\n", "#     return res\n" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "This script also requires some additional libraries. 
We will list them in the `requirements.txt` file under the `code` directory.\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting bigbird-base-mimic-mortality/code/requirements.txt\n" ] } ], "source": [ "%%writefile $model_id/code/requirements.txt\n", "\n", "pandas\n", "sentencepiece==0.1.97\n", "transformers==4.18.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's verify that both files exist." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "tags": [] }, "outputs": [], "source": [ "assert os.path.exists(\"./{}/code/inference.py\".format(model_id))\n", "assert os.path.exists(\"./{}/code/requirements.txt\".format(model_id))" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Then let's create the `hf_model.tar.gz` archive with all the model artifacts and the `inference.py` script." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "tags": [] }, "outputs": [], "source": [ "import tarfile\n", "\n", "file_name = \"hf_model.tar.gz\"\n", "with tarfile.open(file_name, mode=\"w:gz\") as archive:\n", "    archive.add(model_id, recursive=True)\n", "\n", "assert os.path.exists(file_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploying the Hugging Face model on SageMaker\n", "\n", "### Set configurations" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "tags": [] }, "outputs": [], "source": [ "sess = sagemaker.Session()\n", "sess = sagemaker.Session(default_bucket=sess.default_bucket())\n", "sagemaker_client = boto3.client(\"sagemaker\")\n", "region = sess.boto_region_name\n", "bucket = sess.default_bucket()\n", "prefix = \"sagemaker/DEMO-sagemaker-clarify-text\"\n", "\n", "# Define the IAM role\n", "role = sagemaker.get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload the 
hf_model.tar.gz to S3" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'s3://sagemaker-us-east-1-721929407510/sagemaker/DEMO-sagemaker-clarify-text/hf_model.tar.gz'" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_path_s3 = sess.upload_data(path=\"hf_model.tar.gz\", key_prefix=prefix)\n", "model_path_s3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can create a `HuggingFaceModel` object that points to the uploaded archive and the custom inference code. This object will be used to build the container definition for the SageMaker model." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.huggingface import HuggingFaceModel\n", "\n", "# Create the Hugging Face model object\n", "huggingface_model = HuggingFaceModel(\n", "    model_data=model_path_s3,\n", "    transformers_version=\"4.6.1\",\n", "    pytorch_version=\"1.7.1\",\n", "    py_version=\"py36\",\n", "    role=role,\n", "    source_dir=\"./{}/code\".format(model_id),\n", "    entry_point=\"inference.py\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the instance type on which we are going to deploy this model."
] }, { "cell_type": "code", "execution_count": 39, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'Image': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04',\n", " 'Environment': {'SAGEMAKER_PROGRAM': 'inference.py',\n", " 'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code',\n", " 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", " 'SAGEMAKER_REGION': 'us-east-1'},\n", " 'ModelDataUrl': 's3://sagemaker-us-east-1-721929407510/huggingface-pytorch-inference-2023-02-23-06-03-45-911/model.tar.gz'}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "instance_type = \"ml.g4dn.xlarge\"\n", "container_def = huggingface_model.prepare_container_def(instance_type=instance_type)\n", "container_def" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create model\n", "\n", "The following parameters are required to create a SageMaker model:\n", "\n", "* `ExecutionRoleArn`: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts/ docker images for deployment\n", "\n", "* `ModelName`: name of the SageMaker model.\n", "\n", "* `PrimaryContainer`: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions.\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model created: hospital-triage-model\n" ] } ], "source": [ "model_name = \"hospital-triage-model\"\n", "\n", "sagemaker_client.create_model(\n", " ExecutionRoleArn=role,\n", " ModelName=model_name,\n", " PrimaryContainer=container_def,\n", ")\n", "print(f\"Model created: {model_name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create endpoint config\n", "Create an endpoint configuration by 
calling the `create_endpoint_config` API. Here, supply the same `model_name` used in the `create_model` API call. The `create_endpoint_config` API supports the additional parameter `ClarifyExplainerConfig` to enable the Clarify explainer. The SHAP baseline is mandatory; it can be provided either as inline baseline data (the `ShapBaseline` parameter) or as an S3 baseline file (the `ShapBaselineUri` parameter). Please see the developer guide for the optional parameters.\n", "\n", "Here we use a special token as the baseline." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SHAP baseline: [['']]\n" ] } ], "source": [ "baseline = [[\"\"]]\n", "print(f\"SHAP baseline: {baseline}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `TextConfig` is configured with sentence-level granularity and English as the language. When the granularity is `sentence`, each sentence is a feature, so we need a few sentences per admission note for a good visualization."
] }, { "cell_type": "code", "execution_count": 45, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint-config/hospital-triage-model-ep-config',\n", " 'ResponseMetadata': {'RequestId': 'ed9a7555-c241-4105-bf0d-2ce2ff0c8fe8',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'x-amzn-requestid': 'ed9a7555-c241-4105-bf0d-2ce2ff0c8fe8',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '112',\n", " 'date': 'Thu, 23 Feb 2023 06:06:03 GMT'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "endpoint_config_name = \"hospital-triage-model-ep-config\"\n", "csv_serializer = sagemaker.serializers.CSVSerializer()\n", "json_deserializer = sagemaker.deserializers.JSONDeserializer()\n", "\n", "sagemaker_client.create_endpoint_config(\n", "    EndpointConfigName=endpoint_config_name,\n", "    ProductionVariants=[\n", "        {\n", "            \"VariantName\": \"MainVariant\",\n", "            \"ModelName\": model_name,\n", "            \"InitialInstanceCount\": 1,\n", "            \"InstanceType\": instance_type,\n", "        }\n", "    ],\n", "    ExplainerConfig={\n", "        \"ClarifyExplainerConfig\": {\n", "            \"InferenceConfig\": {\"FeatureTypes\": [\"text\"]},\n", "            \"ShapConfig\": {\n", "                \"ShapBaselineConfig\": {\"ShapBaseline\": csv_serializer.serialize(baseline)},\n", "                \"TextConfig\": {\"Granularity\": \"sentence\", \"Language\": \"en\"},\n", "            },\n", "        }\n", "    },\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create endpoint\n", "Once you have your model and endpoint configuration ready, use the `create_endpoint` API to create your endpoint. The `endpoint_name` must be unique within an AWS Region in your AWS account. The `create_endpoint` call returns immediately, with the endpoint status set to `Creating`; the endpoint itself is provisioned in the background."
] }, { "cell_type": "code", "execution_count": 47, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'EndpointArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint/hospital-triage-prediction-endpoint',\n", " 'ResponseMetadata': {'RequestId': '9b20dd03-bb87-4b1b-b4fb-30b9bce57f45',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'x-amzn-requestid': '9b20dd03-bb87-4b1b-b4fb-30b9bce57f45',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '103',\n", " 'date': 'Thu, 23 Feb 2023 06:06:22 GMT'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "endpoint_name = \"hospital-triage-prediction-endpoint\"\n", "sagemaker_client.create_endpoint(\n", " EndpointName=endpoint_name,\n", " EndpointConfigName=endpoint_config_name,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Wait for the endpoint to be in `InService` state\n" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------!" 
] }, { "data": { "text/plain": [ "{'EndpointName': 'hospital-triage-prediction-endpoint',\n", " 'EndpointArn': 'arn:aws:sagemaker:us-east-1:721929407510:endpoint/hospital-triage-prediction-endpoint',\n", " 'EndpointConfigName': 'hospital-triage-model-ep-config',\n", " 'ProductionVariants': [{'VariantName': 'MainVariant',\n", " 'DeployedImages': [{'SpecifiedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04',\n", " 'ResolvedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference@sha256:1d2383a5e52c26db3d2262742d211b979b170fa30a1855302a022e0b9018e6c6',\n", " 'ResolutionTime': datetime.datetime(2023, 2, 23, 6, 6, 23, 451000, tzinfo=tzlocal())}],\n", " 'CurrentWeight': 1.0,\n", " 'DesiredWeight': 1.0,\n", " 'CurrentInstanceCount': 1,\n", " 'DesiredInstanceCount': 1}],\n", " 'EndpointStatus': 'InService',\n", " 'CreationTime': datetime.datetime(2023, 2, 23, 6, 6, 22, 751000, tzinfo=tzlocal()),\n", " 'LastModifiedTime': datetime.datetime(2023, 2, 23, 6, 9, 46, 331000, tzinfo=tzlocal()),\n", " 'ExplainerConfig': {'ClarifyExplainerConfig': {'InferenceConfig': {'FeatureTypes': ['text']},\n", " 'ShapConfig': {'ShapBaselineConfig': {'ShapBaseline': ''},\n", " 'TextConfig': {'Language': 'en', 'Granularity': 'sentence'}}}},\n", " 'ResponseMetadata': {'RequestId': 'fca8e9d9-0adb-4c1a-b995-e8bcab572d52',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'x-amzn-requestid': 'fca8e9d9-0adb-4c1a-b995-e8bcab572d52',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '1024',\n", " 'date': 'Thu, 23 Feb 2023 06:09:55 GMT'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sess.wait_for_endpoint(endpoint_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test endpoint without any explanations\n", "\n", "Use the `EnableExplanations` parameter to disable the 
explanations for this request.\n", "\n" ] }, { "cell_type": "code", "execution_count": 80, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'Body': ,\n", " 'ContentType': 'application/json',\n", " 'InvokedProductionVariant': 'MainVariant',\n", " 'ResponseMetadata': {'HTTPHeaders': {'content-length': '100',\n", " 'content-type': 'application/json',\n", " 'date': 'Thu, 23 Feb 2023 06:40:22 GMT',\n", " 'x-amzn-invoked-production-variant': 'MainVariant',\n", " 'x-amzn-requestid': '1df6d126-e2c3-4f86-9495-1a494373c5cf'},\n", " 'HTTPStatusCode': 200,\n", " 'RequestId': '1df6d126-e2c3-4f86-9495-1a494373c5cf',\n", " 'RetryAttempts': 0}}\n" ] } ], "source": [ "sagemaker_runtime_client = boto3.client(\"sagemaker-runtime\")\n", "\n", "sample_admission_note = pd.DataFrame([\"\"\"Patient is a 25-year-old male with a chief complaint of acute chest pain. \n", " Patient reports the pain began suddenly while at work and has been constant since. \n", " Patient rates the pain as 8/10 in severity. Patient denies any radiation of pain, shortness of breath, nausea, or vomiting. \n", " Patient reports no previous history of chest pain. \n", " Vital signs are as follows: blood pressure 140/90 mmH. Heart rate 92 beats per minute. \n", " Respiratory rate 18 breaths per minute. Oxygen saturation 96% on room air. \n", " Physical examination reveals mild tenderness to palpation over the precordium and clear lung fields. \n", " EKG shows sinus tachycardia with no ST-elevations or depressions. 
\"\"\"])\n", "\n", "response = sagemaker_runtime_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType=\"text/csv\",\n", " Accept=\"text/csv\",\n", " Body=csv_serializer.serialize(sample_admission_note.iloc[:1, :].to_numpy()),\n", " EnableExplanations=\"`false`\", # Do not provide explanations\n", ")\n", "\n", "pprint.pprint(response)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'explanations': {},\n", " 'predictions': {'content_type': 'text/csv', 'data': '0.014596084\\n'},\n", " 'version': '1.0'}\n" ] } ], "source": [ "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", "pprint.pprint(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the model predicts this as a non-acute case, since the probability is as low as `0.015`. But which statements in the admission note did the model use to reach that conclusion? To answer that, we can leverage SageMaker Clarify." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explain the predictions using Amazon SageMaker Clarify\n", "\n", "There are expanding business needs, clinical needs, and legislative regulations that require explanations of why a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.\n", "\n", "How does the Kernel SHAP algorithm work? The Kernel SHAP algorithm is a local explanation method. That is, it explains one instance (row) of the dataset at a time. To explain an instance, it perturbs the feature values: it changes the values of some features to a baseline (or non-informative) value, and then gets predictions from the model for the perturbed samples. 
It does this a number of times per instance (determined by the optional parameter `num_samples` in SHAPConfig), and computes the importance of each feature based on how the model prediction changed.\n", "\n", "This functionality also extends to text data. To explain text, we need the TextConfig. The TextConfig is an optional parameter of SHAPConfig, which you must provide if you need explanations for the text features in your dataset. TextConfig in turn takes three parameters:\n", "\n", "* `granularity` (required): To explain text features, Clarify breaks the text down into smaller text units and treats each unit as a feature. This parameter sets the level at which Clarify breaks down the text; the allowed values are token, sentence, and paragraph.\n", "* `language` (required): the language of the text features. Clarify needs it to tokenize the text and break it down into its granular form.\n", "* `max_top_tokens` (optional): the number of top token attributions shown in the output (needed because the vocabulary can be very large). Defaults to 50.\n", "\n", "The Kernel SHAP algorithm requires a baseline (also known as a background dataset). For tabular features, the baseline value for a feature is ideally a non-informative or least informative value for that feature. For text features, however, the baseline must be the value with which you want to replace each individual text unit (token, sentence, or paragraph). For instance, in the example below, we have chosen the baseline value for review_text as , and granularity is sentence. Every time a sentence has to be replaced in the perturbed inputs, we will replace it with .\n", "\n", "If a baseline is not provided, SageMaker Clarify calculates one automatically using K-means or K-prototypes on the input dataset for tabular features. 
For text features, if a baseline is not provided, the default replacement value will be the string ." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test endpoint with explanations\n", "\n", "This time we'll invoke the endpoint with explanations enabled (which is the default setting). " ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'explanations': {'kernel_shap': [[{'attributions': [{'attribution': [-0.13967000051377493],\n", " 'description': {'partial_text': 'Patient '\n", " 'is '\n", " 'a '\n", " '25-year-old '\n", " 'male '\n", " 'with '\n", " 'a '\n", " 'chief '\n", " 'complaint '\n", " 'of '\n", " 'acute '\n", " 'chest '\n", " 'pain.',\n", " 'start_idx': 0}},\n", " {'attribution': [-0.06911266863252466],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'Patient '\n", " 'reports '\n", " 'the '\n", " 'pain '\n", " 'began '\n", " 'suddenly '\n", " 'while '\n", " 'at '\n", " 'work '\n", " 'and '\n", " 'has '\n", " 'been '\n", " 'constant '\n", " 'since.',\n", " 'start_idx': 74}},\n", " {'attribution': [-0.011582978977326262],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'Patient '\n", " 'rates '\n", " 'the '\n", " 'pain '\n", " 'as '\n", " '8/10 '\n", " 'in '\n", " 'severity.',\n", " 'start_idx': 162}},\n", " {'attribution': [-0.03850561143025477],\n", " 'description': {'partial_text': 'Patient '\n", " 'denies '\n", " 'any '\n", " 'radiation '\n", " 'of '\n", " 'pain, '\n", " 'shortness '\n", " 'of '\n", " 'breath, '\n", " 'nausea, '\n", " 'or '\n", " 'vomiting.',\n", " 'start_idx': 211}},\n", " {'attribution': [-0.09674479643979414],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'Patient '\n", " 'reports '\n", " 'no '\n", " 'previous '\n", " 'history '\n", " 'of '\n", " 'chest '\n", " 'pain.',\n", " 'start_idx': 291}},\n", " {'attribution': [-0.011245212250578111],\n", " 'description': {'partial_text': 
'\\n'\n", " ' '\n", " 'Vital '\n", " 'signs '\n", " 'are '\n", " 'as '\n", " 'follows: '\n", " 'blood '\n", " 'pressure '\n", " '140/90 '\n", " 'mmH. '\n", " 'Heart '\n", " 'rate '\n", " '92 '\n", " 'beats '\n", " 'per '\n", " 'minute.',\n", " 'start_idx': 347}},\n", " {'attribution': [0.024608203424975668],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'Respiratory '\n", " 'rate '\n", " '18 '\n", " 'breaths '\n", " 'per '\n", " 'minute.',\n", " 'start_idx': 439}},\n", " {'attribution': [-0.01722827160280489],\n", " 'description': {'partial_text': 'Oxygen '\n", " 'saturation '\n", " '96% '\n", " 'on '\n", " 'room '\n", " 'air.',\n", " 'start_idx': 484}},\n", " {'attribution': [0.11806218603963267],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'Physical '\n", " 'examination '\n", " 'reveals '\n", " 'mild '\n", " 'tenderness '\n", " 'to '\n", " 'palpation '\n", " 'over '\n", " 'the '\n", " 'precordium '\n", " 'and '\n", " 'clear '\n", " 'lung '\n", " 'fields.',\n", " 'start_idx': 519}},\n", " {'attribution': [-0.0920242456175506],\n", " 'description': {'partial_text': '\\n'\n", " ' '\n", " 'EKG '\n", " 'shows '\n", " 'sinus '\n", " 'tachycardia '\n", " 'with '\n", " 'no '\n", " 'ST-elevations '\n", " 'or '\n", " 'depressions.',\n", " 'start_idx': 625}}],\n", " 'feature_type': 'text'}]]},\n", " 'predictions': {'content_type': 'text/csv', 'data': '0.014596084\\n'},\n", " 'version': '1.0'}\n" ] } ], "source": [ "\n", "response = sagemaker_runtime_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType=\"text/csv\",\n", " Accept=\"text/csv\",\n", " Body=csv_serializer.serialize(sample_admission_note.iloc[:1, :].to_numpy())\n", ")\n", "\n", "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", "pprint.pprint(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see the `kernel_shap` values are returned with the response. 
To interpret this at the sentence level let's use some visualizations.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll define some utility functions to get the visualizations of the SHAP values." ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "tags": [] }, "outputs": [], "source": [ "import csv\n", "import numpy as np\n", "from captum.attr import visualization\n", "\n", "def visualization_record(\n", " attributions, # list of attributions for the tokens\n", " text, # list of tokens\n", " pred, # the prediction value obtained from the endpoint\n", " delta,\n", " true_label, # the true label from the dataset\n", " normalize=True, # normalizes the attributions so that the max absolute value is 1. Yields stronger colors.\n", " max_frac_to_show=0.05, # what fraction of tokens to highlight, set to 1 for all.\n", " match_to_pred=False, # whether to limit highlights to red for negative predictions and green for positive ones.\n", " # By enabling `match_to_pred` you show what tokens contribute to a high/low prediction not those that oppose it.\n", "):\n", " \n", " if normalize:\n", " attributions = attributions / max(max(attributions), max(-attributions))\n", " if max_frac_to_show is not None and max_frac_to_show < 1:\n", " num_show = int(max_frac_to_show * attributions.shape[0])\n", " sal = attributions\n", " if pred < 0.5:\n", " sal = -sal\n", " if not match_to_pred:\n", " sal = np.abs(sal)\n", " top_idxs = np.argsort(-sal)[:num_show]\n", " mask = np.zeros_like(attributions)\n", " mask[top_idxs] = 1\n", " attributions = attributions * mask\n", " return visualization.VisualizationDataRecord(\n", " attributions,\n", " pred,\n", " int(pred > 0.5),\n", " true_label,\n", " attributions.sum() > 0,\n", " attributions.sum(),\n", " text,\n", " delta,\n", " )\n", "\n", "def visualize_result(result, all_labels):\n", " if not result[\"explanations\"]:\n", " print(f\"No Clarify explanations for the record(s)\")\n", " return\n", " 
all_explanations = result[\"explanations\"][\"kernel_shap\"]\n", " all_predictions = list(csv.reader(result[\"predictions\"][\"data\"].splitlines()))\n", "\n", " labels = []\n", " predictions = []\n", " explanations = []\n", "\n", " for i, expl in enumerate(all_explanations):\n", " if expl:\n", " labels.append(all_labels[i])\n", " predictions.append(all_predictions[i])\n", " explanations.append(all_explanations[i])\n", "\n", " attributions_dataset = [\n", " np.array([attr[\"attribution\"][0] for attr in expl[0][\"attributions\"]])\n", " for expl in explanations\n", " ]\n", " tokens_dataset = [\n", " np.array([attr[\"description\"][\"partial_text\"] for attr in expl[0][\"attributions\"]])\n", " for expl in explanations\n", " ]\n", "\n", " # You can customize the following display settings\n", " normalize = True\n", " max_frac_to_show = 1\n", " match_to_pred = False\n", " vis = []\n", " for attr, token, pred, label in zip(attributions_dataset, tokens_dataset, predictions, labels):\n", " vis.append(\n", " visualization_record(\n", " -attr, token, float(pred[0]), 0.0, label, normalize, max_frac_to_show, match_to_pred\n", " )\n", " )\n", " _ = visualization.visualize_text(vis)\n", "\n", "def invoke_visualize(test_admission_notes, true_label):\n", " response = sagemaker_runtime_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType=\"text/csv\",\n", " Accept=\"text/csv\",\n", " Body=csv_serializer.serialize(test_admission_notes.iloc[:1, :].to_numpy())\n", " )\n", " result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", " visualize_result(result, [true_label])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain the predictions for a non-acute admission note" ] }, { "cell_type": "code", "execution_count": 83, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
Legend: Negative Neutral Positive
| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
| --- | --- | --- | --- | --- |
| 0 | 0 (0.01) | True | 2.39 | Patient is a 25-year-old male with a chief complaint of acute chest pain. \n", " Patient reports the pain began suddenly while at work and has been constant since. \n", " Patient rates the pain as 8/10 in severity. Patient denies any radiation of pain, shortness of breath, nausea, or vomiting. \n", " Patient reports no previous history of chest pain. \n", " Vital signs are as follows: blood pressure 140/90 mmH. Heart rate 92 beats per minute. \n", " Respiratory rate 18 breaths per minute. Oxygen saturation 96% on room air. \n", " Physical examination reveals mild tenderness to palpation over the precordium and clear lung fields. \n", " EKG shows sinus tachycardia with no ST-elevations or depressions. |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "invoke_visualize(sample_admission_note, 0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain the predictions for an acute admission note" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
Legend: Negative Neutral Positive
| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
| --- | --- | --- | --- | --- |
| [1] | 1 (0.75) | False | -1.26 | Patient is a 72-year-old female with a chief complaint of severe sepsis and septic shock. \n", " Patient reports a fever, chills, and weakness for the past 3 days, as well as decreased urine output and confusion. \n", " Patient has a history of chronic obstructive pulmonary disease (COPD) and a recent hospitalization for pneumonia. \n", " Vital signs are as follows: blood pressure 80/40 mmHg. Heart rate 130 beats per minute. Respiratory rate 30 breaths per minute. \n", " Oxygen saturation 82% on 4L of oxygen via nasal cannula. \n", " Physical examination reveals diffuse erythema and warmth over the lower extremities and positive findings for sepsis such as altered mental status, \n", " tachycardia, and tachypnea. Blood cultures were taken and antibiotic therapy was started with appropriate coverage. |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_admission_notes_accute = pd.DataFrame(\n", " [\"\"\"Patient is a 72-year-old female with a chief complaint of severe sepsis and septic shock. \n", " Patient reports a fever, chills, and weakness for the past 3 days, as well as decreased urine output and confusion. \n", " Patient has a history of chronic obstructive pulmonary disease (COPD) and a recent hospitalization for pneumonia. \n", " Vital signs are as follows: blood pressure 80/40 mmHg. Heart rate 130 beats per minute. Respiratory rate 30 breaths per minute. \n", " Oxygen saturation 82% on 4L of oxygen via nasal cannula. \n", " Physical examination reveals diffuse erythema and warmth over the lower extremities and positive findings for sepsis such as altered mental status, \n", " tachycardia, and tachypnea. Blood cultures were taken and antibiotic therapy was started with appropriate coverage.\"\"\"]\n", " )\n", "\n", "invoke_visualize(test_admission_notes_accute, [1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clean Up\n", "\n", "Delete the deployed model and endpoint to avoid incurring further charges." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Delete the model first, then the endpoint created earlier in this notebook.\n", "huggingface_model.delete_model()\n", "# Reuse `endpoint_name` so that the endpoint deployed above is the one removed.\n", "predictor = sagemaker.Predictor(endpoint_name=endpoint_name)\n", "predictor.delete_endpoint()" ] } ], "metadata": { "instance_type": "ml.m5.large", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }