{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fairness and Explainability with SageMaker Clarify - Spark Distributed Processing"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n",
"\n",
"\n",
"\n",
"---"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runtime\n",
"\n",
"This notebook takes approximately 30 minutes to run.\n",
"\n",
"## Contents\n",
"1. [Overview](#Overview)\n",
"1. [Prerequisites and Data](#Prerequisites-and-Data)\n",
" 1. [Import Libraries](#Import-Libraries)\n",
" 1. [Set Configurations](#Set-Configurations)\n",
" 1. [Download data](#Download-data)\n",
" 1. [Loading the data: Adult Dataset](#Loading-the-data:-Adult-Dataset) \n",
" 1. [Data inspection](#Data-inspection) \n",
" 1. [Encode and Upload the Dataset](#Encode-and-Upload-the-Dataset) \n",
"1. [Train XGBoost Model](#Train-XGBoost-Model)\n",
" 1. [Train Model](#Train-Model)\n",
" 1. [Create Model](#Create-Model)\n",
"1. [Amazon SageMaker Clarify](#Amazon-SageMaker-Clarify)\n",
" 1. [Detecting Bias](#Detecting-Bias)\n",
" 1. [Writing DataConfig](#Writing-DataConfig)\n",
" 1. [Writing ModelConfig](#Writing-ModelConfig)\n",
" 1. [Writing ModelPredictedLabelConfig](#Writing-ModelPredictedLabelConfig)\n",
" 1. [Writing BiasConfig](#Writing-BiasConfig)\n",
" 1. [Pre-training Bias](#Pre-training-Bias)\n",
" 1. [Post-training Bias](#Post-training-Bias)\n",
" 1. [Viewing the Bias Report](#Viewing-the-Bias-Report)\n",
" 1. [Explaining Predictions](#Explaining-Predictions)\n",
" 1. [Viewing the Explainability Report](#Viewing-the-Explainability-Report)\n",
" 1. [Analysis of local explanations](#Analysis-of-local-explanations)\n",
" 1. [Visualize local SHAP values](#Visualize-local-SHAP-values)\n",
"1. [Clean Up](#Clean-Up)\n",
"\n",
"## Overview\n",
"Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models. The product comes with the tools to help you with the following tasks.\n",
"\n",
"* Measure biases that can occur during each stage of the ML lifecycle (data collection, model training and tuning, and monitoring of ML models deployed for inference).\n",
"* Generate model governance reports targeting risk and compliance teams and external regulators.\n",
"* Provide explanations of the data, models, and monitoring used to assess predictions.\n",
"\n",
"This sample notebook walks you through: \n",
"1. Key terms and concepts needed to understand SageMaker Clarify\n",
"1. Measuring the pre-training bias of a dataset and post-training bias of a model\n",
"1. Explaining the importance of the various input features on the model's decision\n",
"1. Accessing the reports through SageMaker Studio if you have an instance set up.\n",
"\n",
"In doing so, the notebook first trains a [SageMaker XGBoost](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html) model on the training dataset, then uses the [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) to launch SageMaker Clarify jobs that analyze an example dataset in CSV format. In particular, this notebook shows how to use Spark distributed processing to run Clarify jobs."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites and Data\n",
"### Import Libraries"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sagemaker import session, get_execution_role\n",
"from io import StringIO\n",
"from s3fs import S3FileSystem\n",
"import sagemaker\n",
"import json\n",
"import pandas as pd\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import boto3"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set Configurations"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Region: us-west-2\n",
"Role: arn:aws:iam::000000000000:role/service-role/AmazonSageMaker-ExecutionRole-20220304T121686\n"
]
}
],
"source": [
"# Initialize sagemaker session\n",
"sagemaker_session = session.Session()\n",
"\n",
"region = sagemaker_session.boto_region_name\n",
"print(f\"Region: {region}\")\n",
"\n",
"role = get_execution_role()\n",
"print(f\"Role: {role}\")\n",
"\n",
"bucket = sagemaker_session.default_bucket()\n",
"\n",
"prefix = \"sagemaker/DEMO-sagemaker-clarify\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download data\n",
"Data Source: [https://archive.ics.uci.edu/ml/machine-learning-databases/adult/](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/)\n",
"\n",
"Let's __download__ the UCI Adult dataset$^{[2]}$ and save it in the local folder as `adult.data` and `adult.test`. The files are read from a mirror in the SageMaker example-files S3 bucket.\n",
"\n",
"$^{[2]}$Dheeru Dua and Efi Karra Taniskidou. \"[UCI Machine Learning Repository](http://archive.ics.uci.edu/ml)\". Irvine, CA: University of California, School of Information and Computer Science (2017)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"adult.data saved!\n",
"adult.test saved!\n"
]
}
],
"source": [
"from sagemaker.s3 import S3Downloader\n",
"\n",
"adult_columns = [\n",
" \"Age\",\n",
" \"Workclass\",\n",
" \"fnlwgt\",\n",
" \"Education\",\n",
" \"Education-Num\",\n",
" \"Marital Status\",\n",
" \"Occupation\",\n",
" \"Relationship\",\n",
" \"Ethnic group\",\n",
" \"Sex\",\n",
" \"Capital Gain\",\n",
" \"Capital Loss\",\n",
" \"Hours per week\",\n",
" \"Country\",\n",
" \"Target\",\n",
"]\n",
"if not os.path.isfile(\"adult.data\"):\n",
" S3Downloader.download(\n",
" s3_uri=\"s3://{}/{}\".format(\n",
" f\"sagemaker-example-files-prod-{region}\", \"datasets/tabular/uci_adult/adult.data\"\n",
" ),\n",
" local_path=\"./\",\n",
" sagemaker_session=sagemaker_session,\n",
" )\n",
" print(\"adult.data saved!\")\n",
"else:\n",
" print(\"adult.data already on disk.\")\n",
"\n",
"if not os.path.isfile(\"adult.test\"):\n",
" S3Downloader.download(\n",
" s3_uri=\"s3://{}/{}\".format(\n",
" f\"sagemaker-example-files-prod-{region}\", \"datasets/tabular/uci_adult/adult.test\"\n",
" ),\n",
" local_path=\"./\",\n",
" sagemaker_session=sagemaker_session,\n",
" )\n",
" print(\"adult.test saved!\")\n",
"else:\n",
" print(\"adult.test already on disk.\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading the data: Adult Dataset\n",
"From the UCI repository of machine learning datasets, this database contains 14 features describing demographic characteristics, with 45,222 rows after removing rows with missing values (30,162 for training and 15,060 for testing). The task is to predict whether a person's yearly income exceeds $50,000.\n",
"\n",
"Here are the features and their possible values:\n",
"\n",
"1. **Age**: continuous.\n",
"1. **Workclass**: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.\n",
"1. **Fnlwgt**: continuous (the number of people the census takers believe that observation represents).\n",
"1. **Education**: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.\n",
"1. **Education-num**: continuous.\n",
"1. **Marital-status**: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.\n",
"1. **Occupation**: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.\n",
"1. **Relationship**: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.\n",
"1. **Ethnic group**: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.\n",
"1. **Sex**: Female, Male.\n",
" * **Note**: this data is extracted from the 1994 Census and enforces a binary option on Sex\n",
"1. **Capital-gain**: continuous.\n",
"1. **Capital-loss**: continuous.\n",
"1. **Hours-per-week**: continuous.\n",
"1. **Native-country**: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.\n",
"\n",
"Next, we specify our binary prediction task: \n",
"\n",
"15. **Target**: <=50K, >50K."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" Age Workclass fnlwgt Education Education-Num \\\n",
"0 39 State-gov 77516 Bachelors 13 \n",
"1 50 Self-emp-not-inc 83311 Bachelors 13 \n",
"2 38 Private 215646 HS-grad 9 \n",
"3 53 Private 234721 11th 7 \n",
"4 28 Private 338409 Bachelors 13 \n",
"\n",
" Marital Status Occupation Relationship Ethnic group Sex \\\n",
"0 Never-married Adm-clerical Not-in-family White Male \n",
"1 Married-civ-spouse Exec-managerial Husband White Male \n",
"2 Divorced Handlers-cleaners Not-in-family White Male \n",
"3 Married-civ-spouse Handlers-cleaners Husband Black Male \n",
"4 Married-civ-spouse Prof-specialty Wife Black Female \n",
"\n",
" Capital Gain Capital Loss Hours per week Country Target \n",
"0 2174 0 40 United-States <=50K \n",
"1 0 0 13 United-States <=50K \n",
"2 0 0 40 United-States <=50K \n",
"3 0 0 40 United-States <=50K \n",
"4 0 0 40 Cuba <=50K "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_data = pd.read_csv(\n",
" \"adult.data\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\"\n",
").dropna()\n",
"\n",
"testing_data = pd.read_csv(\n",
" \"adult.test\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\", skiprows=1\n",
").dropna()\n",
"\n",
"training_data.head()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data inspection\n",
"Plotting the distribution of individual features is a good way to visualize the data, especially for features that can be considered _sensitive_.\n",
"Let's look specifically at the Sex feature of census respondents. There are fewer Female respondents overall, and the gap is larger among positive outcomes (income >50K), where they form only about $\\frac{1}{7}$ of respondents, as the plot below shows."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"training_data[\"Sex\"].where(training_data[\"Target\"] == \">50K\").value_counts().sort_values().plot(\n",
" kind=\"bar\", title=\"Counts of Sex earning >$50K\", rot=0\n",
")"
]
},
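{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the ~1/7 figure numerically rather than reading it off the plot, one option is `value_counts(normalize=True)` on the positive-outcome rows. The cell below is only a sketch: it uses a small, hypothetical toy DataFrame (`toy`) so that it runs on its own; in this notebook you would apply the same expression to `training_data`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy data standing in for `training_data` (string labels, before encoding)\n",
"toy = pd.DataFrame(\n",
"    {\n",
"        \"Sex\": [\"Male\"] * 6 + [\"Female\"] * 4,\n",
"        \"Target\": [\">50K\"] * 7 + [\"<=50K\"] * 3,\n",
"    }\n",
")\n",
"\n",
"# Share of each Sex value among respondents with a positive outcome (>50K)\n",
"positive_share = toy.loc[toy[\"Target\"] == \">50K\", \"Sex\"].value_counts(normalize=True)\n",
"print(positive_share)"
]
},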
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Encode and Upload the Dataset\n",
"Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but it is necessary for the XGBoost model, which expects numeric input."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"\n",
"\n",
"def number_encode_features(df):\n",
"    \"\"\"Label-encode every string (object) column of df; return the encoded copy and the fitted encoders.\"\"\"\n",
"    result = df.copy()\n",
"    encoders = {}\n",
"    for column in result.columns:\n",
"        # `np.object` was removed in NumPy 1.24; compare against the builtin `object` instead\n",
"        if result.dtypes[column] == object:\n",
"            encoders[column] = preprocessing.LabelEncoder()\n",
"            result[column] = encoders[column].fit_transform(result[column].fillna(\"None\"))\n",
"    return result, encoders\n",
"\n",
"\n",
"training_data = pd.concat([training_data[\"Target\"], training_data.drop([\"Target\"], axis=1)], axis=1)\n",
"training_data, _ = number_encode_features(training_data)\n",
"training_data.to_csv(\"train_data.csv\", index=False, header=False)\n",
"\n",
"testing_data, _ = number_encode_features(testing_data)\n",
"test_features = testing_data.drop([\"Target\"], axis=1)\n",
"test_target = testing_data[\"Target\"]\n",
"test_features.to_csv(\"test_features.csv\", index=False, header=False)"
]
},
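{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `number_encode_features` fits a fresh `LabelEncoder` on each DataFrame it receives, so the training and test splits are encoded independently. If a category value appears in only one split (for example, a rare country value), the integer codes of the values that sort after it can differ between the splits. The cell below is a minimal sketch of the issue and of one common remedy, fitting on the training split and reusing that encoder, using hypothetical toy values. Keep in mind that `transform` raises an error for values not seen during fitting, so such values would need handling first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Hypothetical toy values: \"Holand-Netherlands\" appears only in the training split\n",
"train_values = [\"Haiti\", \"Holand-Netherlands\", \"Honduras\"]\n",
"test_values = [\"Haiti\", \"Honduras\"]\n",
"\n",
"# Separate encoders per split: \"Honduras\" gets code 2 in train but 1 in test\n",
"train_codes = LabelEncoder().fit_transform(train_values)\n",
"test_codes_separate = LabelEncoder().fit_transform(test_values)\n",
"print(train_codes.tolist(), test_codes_separate.tolist())\n",
"\n",
"# Reusing the encoder fitted on the training split keeps the codes consistent\n",
"train_encoder = LabelEncoder().fit(train_values)\n",
"test_codes_shared = train_encoder.transform(test_values)\n",
"print(test_codes_shared.tolist())"
]
},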
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick note about our encoding: `LabelEncoder` assigns integer codes in the sorted order of the distinct values, so the \"Female\" Sex value has been encoded as 0 and \"Male\" as 1."
]
},
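{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny standalone check of this mapping, using hypothetical toy values rather than the census data: `classes_` holds the sorted distinct values, and each value's code is its index in `classes_`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"encoder = LabelEncoder()\n",
"codes = encoder.fit_transform([\"Male\", \"Female\", \"Male\"])\n",
"\n",
"# classes_ is sorted, so \"Female\" -> 0 and \"Male\" -> 1\n",
"print(encoder.classes_.tolist())  # ['Female', 'Male']\n",
"print(codes.tolist())  # [1, 0, 1]"
]
},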
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" Target Age Workclass fnlwgt Education Education-Num Marital Status \\\n",
"0 0 39 5 77516 9 13 4 \n",
"1 0 50 4 83311 9 13 2 \n",
"2 0 38 2 215646 11 9 0 \n",
"3 0 53 2 234721 1 7 2 \n",
"4 0 28 2 338409 9 13 2 \n",
"\n",
" Occupation Relationship Ethnic group Sex Capital Gain Capital Loss \\\n",
"0 0 1 4 1 2174 0 \n",
"1 3 0 4 1 0 0 \n",
"2 5 1 4 1 0 0 \n",
"3 5 0 2 1 0 0 \n",
"4 9 5 2 0 0 0 \n",
"\n",
" Hours per week Country \n",
"0 40 38 \n",
"1 13 38 \n",
"2 40 38 \n",
"3 40 38 \n",
"4 40 4 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_data.head()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Lastly, let's upload the encoded training data and test features to S3."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sagemaker.s3 import S3Uploader\n",
"from sagemaker.inputs import TrainingInput\n",
"\n",
"train_uri = S3Uploader.upload(\"train_data.csv\", \"s3://{}/{}\".format(bucket, prefix))\n",
"train_input = TrainingInput(train_uri, content_type=\"csv\")\n",
"test_uri = S3Uploader.upload(\"test_features.csv\", \"s3://{}/{}\".format(bucket, prefix))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train XGBoost Model\n",
"#### Train Model\n",
"Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:sagemaker:Creating training-job with name: sagemaker-xgboost-2023-02-07-03-49-15-216\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"2023-02-07 03:49:15 Starting - Starting the training job..\n",
"2023-02-07 03:49:29 Starting - Preparing the instances for training..........\n",
"2023-02-07 03:50:24 Downloading - Downloading input data....\n",
"2023-02-07 03:50:49 Training - Downloading the training image.....\n",
"2023-02-07 03:51:19 Training - Training image download completed. Training in progress......\n",
"2023-02-07 03:51:50 Uploading - Uploading generated training model.\n",
"2023-02-07 03:52:01 Completed - Training job completed\n"
]
}
],
"source": [
"from sagemaker.image_uris import retrieve\n",
"from sagemaker.estimator import Estimator\n",
"\n",
"container = retrieve(\"xgboost\", region, version=\"1.2-1\")\n",
"xgb = Estimator(\n",
" container,\n",
" role,\n",
" instance_count=1,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" disable_profiler=True,\n",
" sagemaker_session=sagemaker_session,\n",
")\n",
"\n",
"xgb.set_hyperparameters(\n",
" max_depth=5,\n",
" eta=0.2,\n",
" gamma=4,\n",
" min_child_weight=6,\n",
" subsample=0.8,\n",
" objective=\"binary:logistic\",\n",
" num_round=800,\n",
")\n",
"\n",
"xgb.fit({\"train\": train_input}, logs=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
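"#### Create Model\n",
"Here we create a SageMaker model from the training job's artifacts. No endpoint is deployed at this step: SageMaker Clarify sets up and tears down its own endpoint for the model during processing."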
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:sagemaker:Creating model with name: DEMO-clarify-model-07-02-2023-03-52-02\n"
]
},
{
"data": {
"text/plain": [
"'DEMO-clarify-model-07-02-2023-03-52-02'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datetime import datetime\n",
"\n",
"model_name = \"DEMO-clarify-model-{}\".format(datetime.now().strftime(\"%d-%m-%Y-%H-%M-%S\"))\n",
"model = xgb.create_model(name=model_name)\n",
"container_def = model.prepare_container_def()\n",
"sagemaker_session.create_model(model_name, role, container_def)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Amazon SageMaker Clarify\n",
"With your model set up, it's time to explore SageMaker Clarify. For a general overview of how SageMaker Clarify processing jobs work, refer to the [SageMaker Clarify documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-processing-job-configure-how-it-works.html).\n",
"\n",
"When working with large datasets, you can use the Spark processing capabilities of SageMaker Clarify to make your Clarify processing jobs run faster. Clarify uses Spark distributed computing whenever a Clarify processor has more than one instance, so to enable it, set the processor's `instance_count` to a value greater than one."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:sagemaker.image_uris:Defaulting to the only supported framework/algorithm version: 1.0.\n",
"INFO:sagemaker.image_uris:Ignoring unnecessary instance type: None.\n"
]
}
],
"source": [
"from sagemaker import clarify\n",
"\n",
"# Initialize a SageMakerClarifyProcessor to compute bias metrics and model explanations with instance_count > 1\n",
"clarify_processor = clarify.SageMakerClarifyProcessor(\n",
" role=role, instance_count=2, instance_type=\"ml.m5.xlarge\", sagemaker_session=sagemaker_session\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Detecting Bias\n",
"SageMaker Clarify helps you detect possible [pre-training](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-detect-data-bias.html) and [post-training](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-detect-post-training-bias.html) biases using a variety of metrics.\n",
"\n",
"#### Writing DataConfig\n",
"A [DataConfig](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.DataConfig) object communicates basic information about data I/O to SageMaker Clarify. In this example, we provide the following:\n",
"\n",
"* `s3_data_input_path`: S3 URI of the training dataset uploaded above\n",
"* `s3_output_path`: S3 URI where the output report will be uploaded\n",
"* `label`: the ground truth label, also known as the observed label or target attribute. Many bias metrics use it. In this example, the `Target` column holds the ground truth label.\n",
"* `headers`: the list of column names in the dataset\n",
"* `dataset_type`: the format of the dataset; since this example uses a CSV dataset, it is `text/csv`"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"bias_report_output_path = \"s3://{}/{}/clarify-bias\".format(bucket, prefix)\n",
"bias_data_config = clarify.DataConfig(\n",
" s3_data_input_path=train_uri,\n",
" s3_output_path=bias_report_output_path,\n",
" label=\"Target\",\n",
" headers=training_data.columns.to_list(),\n",
" dataset_type=\"text/csv\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Writing ModelConfig\n",
"A [ModelConfig](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.ModelConfig) object communicates information about your trained model. To avoid additional traffic to production models, SageMaker Clarify sets up and tears down a dedicated endpoint during processing. In this example, we provide the following:\n",
"\n",
"* `model_name`: name of the model to analyze; here, the XGBoost model created earlier\n",
"* `instance_type` and `instance_count`: the instance type and number of instances used to host your model during SageMaker Clarify processing. Since the Clarify processing job uses two instances, we recommend also increasing the instance count in the model configuration, so that the shadow endpoint does not become a bottleneck for the processing instances.\n",
"* `accept_type` denotes the endpoint response payload format, and `content_type` denotes the payload format of requests to the endpoint. For the XGBoost model created above, both are `text/csv`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_config = clarify.ModelConfig(\n",
" model_name=model_name,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" instance_count=2,\n",
" accept_type=\"text/csv\",\n",
" content_type=\"text/csv\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Writing ModelPredictedLabelConfig\n",
"A [ModelPredictedLabelConfig](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.ModelPredictedLabelConfig) provides information on the format of your predictions. The XGBoost model outputs a probability for each sample, so SageMaker Clarify invokes the endpoint and then uses `probability_threshold` to convert the probabilities to binary labels for bias analysis. A prediction above the threshold is interpreted as label value `1`, and one below or equal to it as label value `0`."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"predictions_config = clarify.ModelPredictedLabelConfig(probability_threshold=0.8)"
]
},
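{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The thresholding rule described above can be sketched in NumPy as follows. This only illustrates the rule with made-up probabilities; it is not SageMaker Clarify's own code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up model outputs (probability of the positive class)\n",
"probabilities = np.array([0.95, 0.80, 0.42, 0.81])\n",
"probability_threshold = 0.8\n",
"\n",
"# Strictly above the threshold -> 1; below or equal -> 0\n",
"predicted_labels = (probabilities > probability_threshold).astype(int)\n",
"print(predicted_labels.tolist())  # [1, 0, 0, 1]"
]
},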
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Writing BiasConfig\n",
"[BiasConfig](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.BiasConfig) contains configuration values for detecting bias using a Clarify container."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"bias_config = clarify.BiasConfig(\n",
" label_values_or_threshold=[1], facet_name=\"Sex\", facet_values_or_threshold=[0], group_name=\"Age\"\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For our demo we provide the following information in BiasConfig API:\n",
"\n",
"* `label_values_or_threshold`: list of label value(s) or a threshold indicating the positive outcome used for bias metrics. Here the positive outcome is earning >$50,000, encoded as `1`.\n",
"* `facet_name`: the sensitive column of the dataset; here, \"Sex\"\n",
"* `facet_values_or_threshold`: value(s) of the sensitive group; here `[0]`, the encoding of \"Female\" respondents.\n",
"* `group_name`: This example has selected the \"Age\" column which is used to form subgroups for the measurement of bias metric [Conditional Demographic Disparity (CDD)](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html) or [Conditional Demographic Disparity in Predicted Labels (CDDPL)](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html).\n",
"\n",
"SageMaker Clarify can handle both categorical and continuous data for `facet_values_or_threshold` and for `label_values_or_threshold`. In this case we are using categorical data. The results will show whether the model has a preference for records of one sex over the other."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Pre-training Bias\n",
"Bias can be present in your data before any model training occurs. Inspecting your data for bias before training begins can help detect any data collection gaps, inform your feature engineering, and help you understand what societal biases the data may reflect.\n",
"\n",
"Computing pre-training bias metrics does not require a trained model.\n",
"\n",
"#### Post-training Bias\n",
"Computing post-training bias metrics does require a trained model.\n",
"\n",
"Unbiased training data (as measured by the fairness concepts behind the bias metrics) may still result in biased model predictions after training. Whether this occurs depends on several factors, including hyperparameter choices.\n",
"\n",
"\n",
"You can run these analyses separately with `run_pre_training_bias()` and `run_post_training_bias()`, or together with `run_bias()` as shown below. We use the following additional parameters for the API call:\n",
"\n",
"* `pre_training_methods`: pre-training bias metrics to compute. Detailed descriptions of the metrics can be found in [Measure Pre-training Bias](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-measure-data-bias.html). This example sets methods to \"all\" to compute all the pre-training bias metrics.\n",
"* `post_training_methods`: post-training bias metrics to compute. Detailed descriptions of the metrics can be found in [Measure Post-training Bias](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-detect-post-training-bias.html). This example sets methods to \"all\" to compute all the post-training bias metrics."
]
},
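{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the metric definitions concrete, the cell below sketches three of them in pandas: Class Imbalance (CI) and Difference in Proportions of Labels (DPL) before training, and Difference in Positive Proportions in Predicted Labels (DPPL) after training, following the formulas in the SageMaker Clarify documentation linked above. It uses a hypothetical toy frame in which `Sex == 0` is the sensitive facet *d* (as with `facet_values_or_threshold=[0]`) and label `1` is the positive outcome. It is an illustration only, not how Clarify computes the metrics internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy data: Sex (0 = sensitive facet d, 1 = facet a), observed and predicted labels\n",
"toy = pd.DataFrame(\n",
"    {\n",
"        \"Sex\": [0, 0, 1, 1, 1, 1],\n",
"        \"Target\": [0, 1, 1, 1, 1, 0],\n",
"        \"Predicted\": [0, 0, 1, 1, 0, 1],\n",
"    }\n",
")\n",
"\n",
"is_d = toy[\"Sex\"] == 0  # rows in the sensitive facet d\n",
"n_d = int(is_d.sum())\n",
"n_a = int((~is_d).sum())\n",
"\n",
"# Class Imbalance (CI): (n_a - n_d) / (n_a + n_d)\n",
"ci = (n_a - n_d) / (n_a + n_d)\n",
"\n",
"# DPL: proportion of positive observed labels in facet a minus that in facet d\n",
"dpl = (toy.loc[~is_d, \"Target\"] == 1).mean() - (toy.loc[is_d, \"Target\"] == 1).mean()\n",
"\n",
"# DPPL: the same difference computed on the predicted labels\n",
"dppl = (toy.loc[~is_d, \"Predicted\"] == 1).mean() - (toy.loc[is_d, \"Predicted\"] == 1).mean()\n",
"\n",
"print(f\"CI={ci:.3f}, DPL={dpl:.3f}, DPPL={dppl:.3f}\")"
]
},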
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# The job takes about 10 minutes to run\n",
"clarify_processor.run_bias(\n",
" data_config=bias_data_config,\n",
" bias_config=bias_config,\n",
" model_config=model_config,\n",
" model_predicted_label_config=predictions_config,\n",
" pre_training_methods=\"all\",\n",
" post_training_methods=\"all\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Viewing the Bias Report\n",
"In Studio, you can view the results under the experiments tab.\n",
"\n",
"\n",
"\n",
"Each bias metric has detailed explanations with examples that you can explore.\n",
"\n",
"\n",
"\n",
"You could also summarize the results in a handy table!\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're not a Studio user yet, you can access the bias report in PDF, HTML and ipynb formats at the following S3 location:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'s3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-bias'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bias_report_output_path"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explaining Predictions\n",
"There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses Kernel SHAP to explain the contribution that each input feature makes to the final decision.\n",
"\n",
"For the `run_explainability` API call, we need `DataConfig` and `ModelConfig` objects similar to the ones defined above. [SHAPConfig](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.SHAPConfig) is the configuration class for the Kernel SHAP algorithm.\n",
"\n",
"In this example, we pass the following to `SHAPConfig`:\n",
"\n",
"* `baseline`: The Kernel SHAP algorithm requires a baseline (also known as a background dataset). If none is provided, SageMaker Clarify computes one automatically using K-means or K-prototypes on the input dataset. The baseline must have the same format as `dataset_type` and must contain only features (no label). It can be either an S3 URI to a baseline dataset file or an in-place list of samples. Here we use the latter: a single sample holding the mean of each feature in the training dataset. For more details on baseline selection, [refer to this documentation](https://docs.aws.amazon.com/en_us/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html).\n",
"* `num_samples`: Number of samples to be used in the Kernel SHAP algorithm. This number determines the size of the generated synthetic dataset used to compute the SHAP values.\n",
"* `agg_method`: Aggregation method for global SHAP values. Here we use `mean_abs`, i.e. the mean of the absolute SHAP values over all instances.\n",
"* `save_local_shap_values`: Indicates whether to save the local SHAP values in the output location. Default is True."
]
},
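{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for what Kernel SHAP estimates, the next cell is a minimal, self-contained sketch that is not part of the Clarify API. For a hypothetical three-feature linear model (`toy_model`, with made-up inputs), it computes exact Shapley values by enumerating every coalition of features; features outside a coalition take their value from a single-row baseline, which is the role the `baseline` parameter plays. A feature whose value equals the baseline receives zero attribution, and the attributions sum to the difference between the model output on the instance and on the baseline. Exact enumeration grows as 2^M for M features (2^14 = 16,384 coalitions for the 14 features in this dataset), which is why Kernel SHAP instead works from a limited number of sampled coalitions, controlled by `num_samples`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fractions import Fraction\n",
"from itertools import combinations\n",
"from math import factorial\n",
"\n",
"\n",
"def toy_model(x):\n",
"    # Hypothetical linear model, used only for illustration\n",
"    return 2 * x[0] + 3 * x[1] - x[2]\n",
"\n",
"\n",
"toy_instance = [3, 1, 2]\n",
"toy_baseline = [1, 1, 0]  # plays the role of the SHAPConfig baseline\n",
"num_features = len(toy_instance)\n",
"\n",
"\n",
"def coalition_value(coalition):\n",
"    # Features in the coalition keep the instance values; the rest come from the baseline\n",
"    x = [toy_instance[j] if j in coalition else toy_baseline[j] for j in range(num_features)]\n",
"    return toy_model(x)\n",
"\n",
"\n",
"def exact_shapley(i):\n",
"    # Weighted average of feature i's marginal contribution over all coalitions that exclude i\n",
"    others = [j for j in range(num_features) if j != i]\n",
"    total = Fraction(0)\n",
"    for size in range(num_features):\n",
"        weight = Fraction(factorial(size) * factorial(num_features - size - 1), factorial(num_features))\n",
"        for subset in combinations(others, size):\n",
"            coalition = set(subset)\n",
"            total += weight * (coalition_value(coalition | {i}) - coalition_value(coalition))\n",
"    return total\n",
"\n",
"\n",
"toy_phi = [float(exact_shapley(i)) for i in range(num_features)]\n",
"print(toy_phi)  # [4.0, 0.0, -2.0]\n",
"# Efficiency: the attributions add up to f(instance) - f(baseline)\n",
"print(sum(toy_phi), toy_model(toy_instance) - toy_model(toy_baseline))  # 2.0 2"
]
},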
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"explainability_output_path = \"s3://{}/{}/clarify-explainability\".format(bucket, prefix)\n",
"explainability_data_config = clarify.DataConfig(\n",
" s3_data_input_path=train_uri,\n",
" s3_output_path=explainability_output_path,\n",
" label=\"Target\",\n",
" headers=training_data.columns.to_list(),\n",
" dataset_type=\"text/csv\",\n",
")\n",
"\n",
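"# One baseline row holding the mean of each feature; iloc[1:] skips the Target column, which comes first\n",
"baseline = [training_data.mean().iloc[1:].values.tolist()]\n",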
"shap_config = clarify.SHAPConfig(\n",
" baseline=baseline,\n",
" num_samples=15,\n",
" agg_method=\"mean_abs\",\n",
" save_local_shap_values=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# The job takes about 10 minutes to run\n",
"clarify_processor.run_explainability(\n",
" data_config=explainability_data_config,\n",
" model_config=model_config,\n",
" explainability_config=shap_config,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Viewing the Explainability Report\n",
"As with the bias report, you can view the explainability report in Studio under the Experiments tab.\n",
"\n",
"The Model Insights tab contains direct links to the report and model insights.\n",
"\n",
"If you are not a Studio user yet, you can access this report, like the bias report, at the following S3 location."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'s3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-explainability'"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explainability_output_path"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Analysis of local explanations\n",
"You can visualize the local explanations for individual examples in your dataset. Start by reading the global explanations computed by the Kernel SHAP run from the `analysis.json` file in the output location."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Age', 'Capital Gain', 'Capital Loss', 'Country', 'Education',\n",
" 'Education-Num', 'Ethnic group', 'Hours per week', 'Marital Status',\n",
" 'Occupation', 'Relationship', 'Sex', 'Workclass', 'fnlwgt'],\n",
" dtype='object')"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"analysis_result_json = sagemaker.s3.S3Downloader.read_file(\n",
" explainability_output_path + \"/analysis.json\"\n",
")\n",
"analysis_result = json.loads(analysis_result_json)\n",
"shap_values = analysis_result[\"explanations\"][\"kernel_shap\"][\"label0\"][\"global_shap_values\"]\n",
"features = pd.Series(shap_values)\n",
"feature_names = features.index\n",
"feature_names"
]
},
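{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The global SHAP values sit under a nested key path in `analysis.json` (`explanations`, then `kernel_shap`, then `label0`, then `global_shap_values`), and the features are not necessarily ordered by importance. As a sketch, the next cell ranks features by their global attribution using a small, made-up dictionary with the same nesting as `analysis_result`; the numbers are illustrative only and are not results from this job. With the real data, you could apply the same `sort_values` call to the `features` Series loaded above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Made-up content with the same nesting as analysis.json; values are illustrative only\n",
"example_analysis = {\n",
"    'explanations': {'kernel_shap': {'label0': {'global_shap_values': {\n",
"        'Age': 0.04, 'Capital Gain': 0.07, 'Sex': 0.01\n",
"    }}}}\n",
"}\n",
"example_global = pd.Series(\n",
"    example_analysis['explanations']['kernel_shap']['label0']['global_shap_values']\n",
")\n",
"# Rank features from the largest to the smallest global attribution\n",
"ranked = example_global.sort_values(ascending=False)\n",
"print(ranked.index.tolist())  # ['Capital Gain', 'Age', 'Sex']"
]
},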
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With Clarify Spark jobs, the local SHAP values are written as multiple part files rather than a single file, so you need to collate them before you can visualize them."
]
},
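{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal local sketch of the collation step, the next cell writes a few small, made-up part files to a temporary directory (standing in for the Spark output under `explanations_shap/out.csv/`; the file names are hypothetical) and combines them with a single `pd.concat` call. Collecting the pieces first and concatenating once avoids repeatedly copying a growing DataFrame, and it also works on pandas 2.0 and later, where `DataFrame.append` has been removed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import tempfile\n",
"\n",
"import pandas as pd\n",
"\n",
"with tempfile.TemporaryDirectory() as out_dir:\n",
"    # Hypothetical part files, standing in for the partitioned Spark output\n",
"    pd.DataFrame({'Age': [0.1, -0.2], 'Sex': [0.0, 0.3]}).to_csv(\n",
"        os.path.join(out_dir, 'part-00000.csv'), index=False\n",
"    )\n",
"    pd.DataFrame({'Age': [0.05], 'Sex': [-0.1]}).to_csv(\n",
"        os.path.join(out_dir, 'part-00001.csv'), index=False\n",
"    )\n",
"    # Read every part, then concatenate once\n",
"    part_files = sorted(glob.glob(os.path.join(out_dir, '*.csv')))\n",
"    collated = pd.concat((pd.read_csv(path) for path in part_files), ignore_index=True)\n",
"\n",
"print(collated.shape)  # (3, 2)"
]
},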
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 128 files in S3\n"
]
}
],
"source": [
"_s3 = boto3.resource(\"s3\")\n",
"my_bucket = _s3.Bucket(bucket)\n",
"s3_files = [\n",
" \"s3://{}/{}\".format(obj.bucket_name, obj.key)\n",
" for obj in my_bucket.objects.filter(\n",
" Prefix=prefix + \"/clarify-explainability/explanations_shap/out.csv/\"\n",
" )\n",
" if obj.key.endswith(\".csv\")\n",
"]\n",
"print(f\"Found {len(s3_files)} files in S3\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Opening every file can take a while; lower num_files_to_open to read only a subset\n",
"num_files_to_open = len(s3_files)\n",
"frames = []\n",
"for file in s3_files[:num_files_to_open]:\n",
"    output = sagemaker.s3.S3Downloader.read_file(file)\n",
"    frames.append(pd.read_csv(StringIO(output), sep=\",\"))\n",
"# DataFrame.append was removed in pandas 2.0, so concatenate the pieces once\n",
"local_shap_values = pd.concat(frames, ignore_index=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Visualize local SHAP values"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"max_features_to_display = 15\n",
"feature_names = local_shap_values.columns[:max_features_to_display]\n",
"\n",
"fig = plt.figure(figsize=(max_features_to_display, max_features_to_display))\n",
"low = local_shap_values.min().min()\n",
"high = local_shap_values.max().max()\n",
"\n",
"i = 1\n",
"for feature_name in feature_names:\n",
" plt.subplot(max_features_to_display, 1, i)\n",
" shap_value = local_shap_values[f\"{feature_name}\"].to_frame()\n",
" feature = pd.Series([feature_name] * shap_value.shape[0]).to_frame()\n",
" df = pd.concat([shap_value, feature], axis=1, join=\"inner\", ignore_index=True)\n",
" df.columns = [\"shap_value\", \"feature\"]\n",
" num_rows_to_display = min(df.shape[0], 500)\n",
" df = df.sample(num_rows_to_display)\n",
" ax = sns.violinplot(\n",
" y=\"feature\",\n",
" x=\"shap_value\",\n",
" data=df,\n",
" size=6,\n",
" color=\"#f5f5f5\",\n",
" inner=\"quartile\",\n",
" bw=0.2,\n",
" cut=0,\n",
" orient=\"h\",\n",
" )\n",
" ax.set_xlim(low, high)\n",
" sns.stripplot(\n",
" y=\"feature\",\n",
" x=\"shap_value\",\n",
" data=df,\n",
" size=4,\n",
" orient=\"h\",\n",
" )\n",
" ax.vlines(0, -1, 1, color=\"#ff0000\", linewidth=2)\n",
" ax.set_ylabel(\"\")\n",
" ax.legend([], [], frameon=False)\n",
" i += 1\n",
"\n",
"\n",
"plt.xlabel(\"Local SHAP Values\", fontsize=14)\n",
"plt.tight_layout()\n",
"plt.subplots_adjust(hspace=0, wspace=0.1)\n",
"\n",
"plt.show()"
]
},
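{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The plotting cell draws features in column order. For datasets with more features than `max_features_to_display`, one option (a sketch, not what the cell above does) is to show the features with the largest mean absolute local SHAP value, which is the same statistic the `mean_abs` aggregation method uses to turn local values into global ones. The next cell illustrates this on a small, made-up DataFrame; with the real data, you could use `local_shap_values` in place of the hypothetical `toy_local_shap`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Made-up local SHAP values: one row per explained record, one column per feature\n",
"toy_local_shap = pd.DataFrame(\n",
"    {\n",
"        'Age': [0.2, -0.4, 0.3],\n",
"        'Sex': [0.01, -0.02, 0.0],\n",
"        'Capital Gain': [-0.9, 0.1, 0.2],\n",
"    }\n",
")\n",
"top_k = 2\n",
"# Mean absolute value per feature (the mean_abs aggregation)\n",
"toy_mean_abs = toy_local_shap.abs().mean()\n",
"top_features = toy_mean_abs.sort_values(ascending=False).index[:top_k].tolist()\n",
"print(top_features)  # ['Capital Gain', 'Age']"
]
},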
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** You can run both bias and explainability jobs at the same time with `run_bias_and_explainability()`; refer to the [API documentation](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.SageMakerClarifyProcessor.run_bias_and_explainability) for more details."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clean Up\n",
"Finally, don't forget to clean up the resources we set up and used for this demo!"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:sagemaker:Deleting model with name: DEMO-clarify-model-07-02-2023-03-52-02\n"
]
}
],
"source": [
"sagemaker_session.delete_model(model_name)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notebook CI Test Results\n",
"\n",
"This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook."
]
}
],
"metadata": {
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (Data Science 3.0)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}