{
"cells": [
{
"cell_type": "markdown",
"id": "083ded7d",
"metadata": {
"papermill": {
"duration": 0.020102,
"end_time": "2022-04-18T15:42:07.746324",
"exception": false,
"start_time": "2022-04-18T15:42:07.726222",
"status": "completed"
},
"tags": []
},
"source": [
"# Fairness and Explainability with SageMaker Clarify"
]
},
{
"cell_type": "markdown",
"id": "869afe6f",
"metadata": {
"papermill": {
"duration": 0.0202,
"end_time": "2022-04-18T15:42:07.786675",
"exception": false,
"start_time": "2022-04-18T15:42:07.766475",
"status": "completed"
},
"tags": []
},
"source": [
"## Runtime\n",
"\n",
"This notebook takes approximately 30 minutes to run.\n",
"\n",
"## Contents\n",
"\n",
"1. [Overview](#Overview)\n",
"1. [Prerequisites and Data](#Prerequisites-and-Data)\n",
" 1. [Initialize SageMaker](#Initialize-SageMaker)\n",
" 1. [Download data](#Download-data)\n",
" 1. [Loading the data: Adult Dataset](#Loading-the-data:-Adult-Dataset) \n",
" 1. [Data inspection](#Data-inspection) \n",
" 1. [Data encoding and upload to S3](#Encode-and-Upload-the-Data) \n",
"1. [Train and Deploy XGBoost Model](#Train-XGBoost-Model)\n",
" 1. [Train Model](#Train-Model)\n",
" 1. [Create Model](#Create-Model)\n",
"1. [Amazon SageMaker Clarify](#Amazon-SageMaker-Clarify)\n",
" 1. [Detecting Bias](#Detecting-Bias)\n",
" 1. [Writing BiasConfig](#Writing-BiasConfig)\n",
" 1. [Pre-training Bias](#Pre-training-Bias)\n",
" 1. [Post-training Bias](#Post-training-Bias)\n",
" 1. [Viewing the Bias Report](#Viewing-the-Bias-Report)\n",
" 1. [Explaining Predictions](#Explaining-Predictions)\n",
" 1. [Viewing the Explainability Report](#Viewing-the-Explainability-Report)\n",
"1. [Clean Up](#Clean-Up)\n",
"\n",
"## Overview\n",
"Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models. The product comes with the tools to help you with the following tasks.\n",
"\n",
"* Measure biases that can occur during each stage of the ML lifecycle (data collection, model training and tuning, and monitoring of ML models deployed for inference).\n",
"* Generate model governance reports targeting risk and compliance teams and external regulators.\n",
"* Provide explanations of the data, models, and monitoring used to assess predictions.\n",
"\n",
"This sample notebook walks you through: \n",
"1. Key terms and concepts needed to understand SageMaker Clarify\n",
"1. Measuring the pre-training bias of a dataset and post-training bias of a model\n",
"1. Explaining the importance of the various input features on the model's decision\n",
"1. Accessing the reports through SageMaker Studio if you have an instance set up.\n",
"\n",
"In doing so, the notebook first trains a [SageMaker XGBoost](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html) model using training dataset, then use SageMaker Clarify to analyze a testing dataset in CSV format. SageMaker Clarify also supports analyzing dataset in [SageMaker JSON Lines dense format](https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#common-in-formats), which is illustrated in [another notebook](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-clarify/fairness_and_explainability/fairness_and_explainability_jsonlines_format.ipynb)."
]
},
{
"cell_type": "markdown",
"id": "d009cb6d",
"metadata": {
"papermill": {
"duration": 0.020058,
"end_time": "2022-04-18T15:42:07.826883",
"exception": false,
"start_time": "2022-04-18T15:42:07.806825",
"status": "completed"
},
"tags": []
},
"source": [
"## Prerequisites and Data\n",
"### Initialize SageMaker"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d05d4578",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:07.873260Z",
"iopub.status.busy": "2022-04-18T15:42:07.872366Z",
"iopub.status.idle": "2022-04-18T15:42:09.367100Z",
"shell.execute_reply": "2022-04-18T15:42:09.366670Z"
},
"papermill": {
"duration": 1.520269,
"end_time": "2022-04-18T15:42:09.367251",
"exception": false,
"start_time": "2022-04-18T15:42:07.846982",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from sagemaker import Session\n",
"\n",
"session = Session()\n",
"bucket = session.default_bucket()\n",
"prefix = \"sagemaker/DEMO-sagemaker-clarify\"\n",
"region = session.boto_region_name\n",
"# Define IAM role\n",
"from sagemaker import get_execution_role\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"import boto3\n",
"from datetime import datetime\n",
"\n",
"role = get_execution_role()\n",
"s3_client = boto3.client(\"s3\")"
]
},
{
"cell_type": "markdown",
"id": "0372a3cf",
"metadata": {
"papermill": {
"duration": 0.021199,
"end_time": "2022-04-18T15:42:09.408780",
"exception": false,
"start_time": "2022-04-18T15:42:09.387581",
"status": "completed"
},
"tags": []
},
"source": [
"### Download data\n",
"Data Source: [https://archive.ics.uci.edu/ml/machine-learning-databases/adult/](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/)\n",
"\n",
"Let's __download__ the data and save it in the local folder with the name adult.data and adult.test from UCI repository$^{[2]}$.\n",
"\n",
"$^{[2]}$Dua Dheeru, and Efi Karra Taniskidou. \"[UCI Machine Learning Repository](http://archive.ics.uci.edu/ml)\". Irvine, CA: University of California, School of Information and Computer Science (2017)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fc24db78",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:09.459240Z",
"iopub.status.busy": "2022-04-18T15:42:09.458522Z",
"iopub.status.idle": "2022-04-18T15:42:10.838357Z",
"shell.execute_reply": "2022-04-18T15:42:10.837796Z"
},
"papermill": {
"duration": 1.409563,
"end_time": "2022-04-18T15:42:10.838505",
"exception": false,
"start_time": "2022-04-18T15:42:09.428942",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"adult.data saved!\n",
"adult.test saved!\n"
]
}
],
"source": [
"adult_columns = [\n",
" \"Age\",\n",
" \"Workclass\",\n",
" \"fnlwgt\",\n",
" \"Education\",\n",
" \"Education-Num\",\n",
" \"Marital Status\",\n",
" \"Occupation\",\n",
" \"Relationship\",\n",
" \"Ethnic group\",\n",
" \"Sex\",\n",
" \"Capital Gain\",\n",
" \"Capital Loss\",\n",
" \"Hours per week\",\n",
" \"Country\",\n",
" \"Target\",\n",
"]\n",
"if not os.path.isfile(\"adult.data\"):\n",
" s3_client.download_file(\n",
" \"sagemaker-sample-files\", \"datasets/tabular/uci_adult/adult.data\", \"adult.data\"\n",
" )\n",
" print(\"adult.data saved!\")\n",
"else:\n",
" print(\"adult.data already on disk.\")\n",
"\n",
"if not os.path.isfile(\"adult.test\"):\n",
" s3_client.download_file(\n",
" \"sagemaker-sample-files\", \"datasets/tabular/uci_adult/adult.test\", \"adult.test\"\n",
" )\n",
" print(\"adult.test saved!\")\n",
"else:\n",
" print(\"adult.test already on disk.\")"
]
},
{
"cell_type": "markdown",
"id": "b00899c5",
"metadata": {
"papermill": {
"duration": 0.02317,
"end_time": "2022-04-18T15:42:10.885485",
"exception": false,
"start_time": "2022-04-18T15:42:10.862315",
"status": "completed"
},
"tags": []
},
"source": [
"### Loading the data: Adult Dataset\n",
"From the UCI repository of machine learning datasets, this database contains 14 features concerning demographic characteristics of 45,222 rows (32,561 for training and 12,661 for testing). The task is to predict whether a person has a yearly income that is more or less than $50,000.\n",
"\n",
"Here are the features and their possible values:\n",
"1. **Age**: continuous.\n",
"1. **Workclass**: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.\n",
"1. **Fnlwgt**: continuous (the number of people the census takers believe that observation represents).\n",
"1. **Education**: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.\n",
"1. **Education-num**: continuous.\n",
"1. **Marital-status**: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.\n",
"1. **Occupation**: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.\n",
"1. **Relationship**: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.\n",
"1. **Ethnic group**: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.\n",
"1. **Sex**: Female, Male.\n",
" * **Note**: this data is extracted from the 1994 Census and enforces a binary option on Sex\n",
"1. **Capital-gain**: continuous.\n",
"1. **Capital-loss**: continuous.\n",
"1. **Hours-per-week**: continuous.\n",
"1. **Native-country**: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.\n",
"\n",
"Next, we specify our binary prediction task: \n",
"15. **Target**: <=50,000, >$50,000."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "be1e2a1f",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:10.934980Z",
"iopub.status.busy": "2022-04-18T15:42:10.934029Z",
"iopub.status.idle": "2022-04-18T15:42:11.981113Z",
"shell.execute_reply": "2022-04-18T15:42:11.982318Z"
},
"papermill": {
"duration": 1.076158,
"end_time": "2022-04-18T15:42:11.982501",
"exception": false,
"start_time": "2022-04-18T15:42:10.906343",
"status": "completed"
},
"scrolled": true,
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Workclass \n",
" fnlwgt \n",
" Education \n",
" Education-Num \n",
" Marital Status \n",
" Occupation \n",
" Relationship \n",
" Ethnic group \n",
" Sex \n",
" Capital Gain \n",
" Capital Loss \n",
" Hours per week \n",
" Country \n",
" Target \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 39 \n",
" State-gov \n",
" 77516 \n",
" Bachelors \n",
" 13 \n",
" Never-married \n",
" Adm-clerical \n",
" Not-in-family \n",
" White \n",
" Male \n",
" 2174 \n",
" 0 \n",
" 40 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 1 \n",
" 50 \n",
" Self-emp-not-inc \n",
" 83311 \n",
" Bachelors \n",
" 13 \n",
" Married-civ-spouse \n",
" Exec-managerial \n",
" Husband \n",
" White \n",
" Male \n",
" 0 \n",
" 0 \n",
" 13 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 2 \n",
" 38 \n",
" Private \n",
" 215646 \n",
" HS-grad \n",
" 9 \n",
" Divorced \n",
" Handlers-cleaners \n",
" Not-in-family \n",
" White \n",
" Male \n",
" 0 \n",
" 0 \n",
" 40 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 3 \n",
" 53 \n",
" Private \n",
" 234721 \n",
" 11th \n",
" 7 \n",
" Married-civ-spouse \n",
" Handlers-cleaners \n",
" Husband \n",
" Black \n",
" Male \n",
" 0 \n",
" 0 \n",
" 40 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 4 \n",
" 28 \n",
" Private \n",
" 338409 \n",
" Bachelors \n",
" 13 \n",
" Married-civ-spouse \n",
" Prof-specialty \n",
" Wife \n",
" Black \n",
" Female \n",
" 0 \n",
" 0 \n",
" 40 \n",
" Cuba \n",
" <=50K \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Age Workclass fnlwgt Education Education-Num \\\n",
"0 39 State-gov 77516 Bachelors 13 \n",
"1 50 Self-emp-not-inc 83311 Bachelors 13 \n",
"2 38 Private 215646 HS-grad 9 \n",
"3 53 Private 234721 11th 7 \n",
"4 28 Private 338409 Bachelors 13 \n",
"\n",
" Marital Status Occupation Relationship Ethnic group Sex \\\n",
"0 Never-married Adm-clerical Not-in-family White Male \n",
"1 Married-civ-spouse Exec-managerial Husband White Male \n",
"2 Divorced Handlers-cleaners Not-in-family White Male \n",
"3 Married-civ-spouse Handlers-cleaners Husband Black Male \n",
"4 Married-civ-spouse Prof-specialty Wife Black Female \n",
"\n",
" Capital Gain Capital Loss Hours per week Country Target \n",
"0 2174 0 40 United-States <=50K \n",
"1 0 0 13 United-States <=50K \n",
"2 0 0 40 United-States <=50K \n",
"3 0 0 40 United-States <=50K \n",
"4 0 0 40 Cuba <=50K "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_data = pd.read_csv(\n",
" \"adult.data\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\"\n",
").dropna()\n",
"\n",
"testing_data = pd.read_csv(\n",
" \"adult.test\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\", skiprows=1\n",
").dropna()\n",
"\n",
"training_data.head()"
]
},
{
"cell_type": "markdown",
"id": "3b11ce50",
"metadata": {
"papermill": {
"duration": 0.029795,
"end_time": "2022-04-18T15:42:12.041291",
"exception": false,
"start_time": "2022-04-18T15:42:12.011496",
"status": "completed"
},
"tags": []
},
"source": [
"### Data inspection\n",
"Plotting histograms for the distribution of the different features is a good way to visualize the data. Let's plot a few of the features that can be considered _sensitive_. \n",
"Let's take a look specifically at the Sex feature of a census respondent. In the first plot we see that there are fewer Female respondents as a whole but especially in the positive outcomes, where they form ~$\\frac{1}{7}$th of respondents."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9050f337",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:12.107265Z",
"iopub.status.busy": "2022-04-18T15:42:12.106139Z",
"iopub.status.idle": "2022-04-18T15:42:14.162289Z",
"shell.execute_reply": "2022-04-18T15:42:14.162672Z"
},
"papermill": {
"duration": 2.09361,
"end_time": "2022-04-18T15:42:14.162800",
"exception": false,
"start_time": "2022-04-18T15:42:12.069190",
"status": "completed"
},
"scrolled": true,
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"training_data[\"Sex\"].value_counts().sort_values().plot(kind=\"bar\", title=\"Counts of Sex\", rot=0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e7151483",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:14.279169Z",
"iopub.status.busy": "2022-04-18T15:42:14.277102Z",
"iopub.status.idle": "2022-04-18T15:42:14.584591Z",
"shell.execute_reply": "2022-04-18T15:42:14.584989Z"
},
"papermill": {
"duration": 0.399927,
"end_time": "2022-04-18T15:42:14.585118",
"exception": false,
"start_time": "2022-04-18T15:42:14.185191",
"status": "completed"
},
"scrolled": true,
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"$50K'}>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"training_data[\"Sex\"].where(training_data[\"Target\"] == \">50K\").value_counts().sort_values().plot(\n",
" kind=\"bar\", title=\"Counts of Sex earning >$50K\", rot=0\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c4a4afec",
"metadata": {
"papermill": {
"duration": 0.086259,
"end_time": "2022-04-18T15:42:14.767130",
"exception": false,
"start_time": "2022-04-18T15:42:14.680871",
"status": "completed"
},
"tags": []
},
"source": [
"### Encode and Upload the Dataset\n",
"Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but is necessary for the model."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c54fd61c",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:14.886149Z",
"iopub.status.busy": "2022-04-18T15:42:14.885365Z",
"iopub.status.idle": "2022-04-18T15:42:16.872703Z",
"shell.execute_reply": "2022-04-18T15:42:16.872261Z"
},
"papermill": {
"duration": 2.082627,
"end_time": "2022-04-18T15:42:16.872819",
"exception": false,
"start_time": "2022-04-18T15:42:14.790192",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"\n",
"\n",
"def number_encode_features(df):\n",
" result = df.copy()\n",
" encoders = {}\n",
" for column in result.columns:\n",
" if result.dtypes[column] == np.object:\n",
" encoders[column] = preprocessing.LabelEncoder()\n",
" result[column] = encoders[column].fit_transform(result[column].fillna(\"None\"))\n",
" return result, encoders\n",
"\n",
"\n",
"training_data = pd.concat([training_data[\"Target\"], training_data.drop([\"Target\"], axis=1)], axis=1)\n",
"training_data, _ = number_encode_features(training_data)\n",
"training_data.to_csv(\"train_data.csv\", index=False, header=False)\n",
"\n",
"testing_data, _ = number_encode_features(testing_data)\n",
"test_features = testing_data.drop([\"Target\"], axis=1)\n",
"test_target = testing_data[\"Target\"]\n",
"test_features.to_csv(\"test_features.csv\", index=False, header=False)"
]
},
{
"cell_type": "markdown",
"id": "0dff253a",
"metadata": {
"papermill": {
"duration": 0.085611,
"end_time": "2022-04-18T15:42:16.982827",
"exception": false,
"start_time": "2022-04-18T15:42:16.897216",
"status": "completed"
},
"tags": []
},
"source": [
"A quick note about our encoding: the \"Female\" Sex value has been encoded as 0 and \"Male\" as 1."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "53706464",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:17.176444Z",
"iopub.status.busy": "2022-04-18T15:42:17.175481Z",
"iopub.status.idle": "2022-04-18T15:42:17.179230Z",
"shell.execute_reply": "2022-04-18T15:42:17.179614Z"
},
"papermill": {
"duration": 0.111035,
"end_time": "2022-04-18T15:42:17.179749",
"exception": false,
"start_time": "2022-04-18T15:42:17.068714",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Target \n",
" Age \n",
" Workclass \n",
" fnlwgt \n",
" Education \n",
" Education-Num \n",
" Marital Status \n",
" Occupation \n",
" Relationship \n",
" Ethnic group \n",
" Sex \n",
" Capital Gain \n",
" Capital Loss \n",
" Hours per week \n",
" Country \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0 \n",
" 39 \n",
" 5 \n",
" 77516 \n",
" 9 \n",
" 13 \n",
" 4 \n",
" 0 \n",
" 1 \n",
" 4 \n",
" 1 \n",
" 2174 \n",
" 0 \n",
" 40 \n",
" 38 \n",
" \n",
" \n",
" 1 \n",
" 0 \n",
" 50 \n",
" 4 \n",
" 83311 \n",
" 9 \n",
" 13 \n",
" 2 \n",
" 3 \n",
" 0 \n",
" 4 \n",
" 1 \n",
" 0 \n",
" 0 \n",
" 13 \n",
" 38 \n",
" \n",
" \n",
" 2 \n",
" 0 \n",
" 38 \n",
" 2 \n",
" 215646 \n",
" 11 \n",
" 9 \n",
" 0 \n",
" 5 \n",
" 1 \n",
" 4 \n",
" 1 \n",
" 0 \n",
" 0 \n",
" 40 \n",
" 38 \n",
" \n",
" \n",
" 3 \n",
" 0 \n",
" 53 \n",
" 2 \n",
" 234721 \n",
" 1 \n",
" 7 \n",
" 2 \n",
" 5 \n",
" 0 \n",
" 2 \n",
" 1 \n",
" 0 \n",
" 0 \n",
" 40 \n",
" 38 \n",
" \n",
" \n",
" 4 \n",
" 0 \n",
" 28 \n",
" 2 \n",
" 338409 \n",
" 9 \n",
" 13 \n",
" 2 \n",
" 9 \n",
" 5 \n",
" 2 \n",
" 0 \n",
" 0 \n",
" 0 \n",
" 40 \n",
" 4 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Target Age Workclass fnlwgt Education Education-Num Marital Status \\\n",
"0 0 39 5 77516 9 13 4 \n",
"1 0 50 4 83311 9 13 2 \n",
"2 0 38 2 215646 11 9 0 \n",
"3 0 53 2 234721 1 7 2 \n",
"4 0 28 2 338409 9 13 2 \n",
"\n",
" Occupation Relationship Ethnic group Sex Capital Gain Capital Loss \\\n",
"0 0 1 4 1 2174 0 \n",
"1 3 0 4 1 0 0 \n",
"2 5 1 4 1 0 0 \n",
"3 5 0 2 1 0 0 \n",
"4 9 5 2 0 0 0 \n",
"\n",
" Hours per week Country \n",
"0 40 38 \n",
"1 13 38 \n",
"2 40 38 \n",
"3 40 38 \n",
"4 40 4 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_data.head()"
]
},
{
"cell_type": "markdown",
"id": "fa10543a",
"metadata": {
"papermill": {
"duration": 0.024025,
"end_time": "2022-04-18T15:42:17.290367",
"exception": false,
"start_time": "2022-04-18T15:42:17.266342",
"status": "completed"
},
"tags": []
},
"source": [
"Lastly, let's upload the data to S3."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c1e828e1",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:17.776575Z",
"iopub.status.busy": "2022-04-18T15:42:17.472573Z",
"iopub.status.idle": "2022-04-18T15:42:18.190540Z",
"shell.execute_reply": "2022-04-18T15:42:18.190923Z"
},
"papermill": {
"duration": 0.813762,
"end_time": "2022-04-18T15:42:18.191062",
"exception": false,
"start_time": "2022-04-18T15:42:17.377300",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from sagemaker.s3 import S3Uploader\n",
"from sagemaker.inputs import TrainingInput\n",
"\n",
"train_uri = S3Uploader.upload(\"train_data.csv\", \"s3://{}/{}\".format(bucket, prefix))\n",
"train_input = TrainingInput(train_uri, content_type=\"csv\")\n",
"test_uri = S3Uploader.upload(\"test_features.csv\", \"s3://{}/{}\".format(bucket, prefix))"
]
},
{
"cell_type": "markdown",
"id": "d9e9b77c",
"metadata": {
"papermill": {
"duration": 0.02369,
"end_time": "2022-04-18T15:42:18.238778",
"exception": false,
"start_time": "2022-04-18T15:42:18.215088",
"status": "completed"
},
"tags": []
},
"source": [
"### Train XGBoost Model\n",
"#### Train Model\n",
"Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fcab8273",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:42:18.371152Z",
"iopub.status.busy": "2022-04-18T15:42:18.370192Z",
"iopub.status.idle": "2022-04-18T15:45:44.661620Z",
"shell.execute_reply": "2022-04-18T15:45:44.662200Z"
},
"papermill": {
"duration": 206.396586,
"end_time": "2022-04-18T15:45:44.662348",
"exception": false,
"start_time": "2022-04-18T15:42:18.265762",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"2022-04-18 15:42:18 Starting - Starting the training job...\n",
"2022-04-18 15:42:34 Starting - Preparing the instances for training............\n",
"2022-04-18 15:43:42 Downloading - Downloading input data...\n",
"2022-04-18 15:44:02 Training - Downloading the training image...........\n",
"2022-04-18 15:45:03 Training - Training image download completed. Training in progress......\n",
"2022-04-18 15:45:34 Uploading - Uploading generated training model.\n",
"2022-04-18 15:45:40 Completed - Training job completed\n"
]
}
],
"source": [
"from sagemaker.image_uris import retrieve\n",
"from sagemaker.estimator import Estimator\n",
"\n",
"container = retrieve(\"xgboost\", region, version=\"1.2-1\")\n",
"xgb = Estimator(\n",
" container,\n",
" role,\n",
" instance_count=1,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" disable_profiler=True,\n",
" sagemaker_session=session,\n",
")\n",
"\n",
"xgb.set_hyperparameters(\n",
" max_depth=5,\n",
" eta=0.2,\n",
" gamma=4,\n",
" min_child_weight=6,\n",
" subsample=0.8,\n",
" objective=\"binary:logistic\",\n",
" num_round=800,\n",
")\n",
"\n",
"xgb.fit({\"train\": train_input}, logs=False)"
]
},
{
"cell_type": "markdown",
"id": "5d67780c",
"metadata": {
"papermill": {
"duration": 0.044056,
"end_time": "2022-04-18T15:45:44.748690",
"exception": false,
"start_time": "2022-04-18T15:45:44.704634",
"status": "completed"
},
"tags": []
},
"source": [
"#### Create Model\n",
"Here we create the SageMaker model."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f39807a0",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:44.844076Z",
"iopub.status.busy": "2022-04-18T15:45:44.843305Z",
"iopub.status.idle": "2022-04-18T15:45:45.241770Z",
"shell.execute_reply": "2022-04-18T15:45:45.242483Z"
},
"papermill": {
"duration": 0.445712,
"end_time": "2022-04-18T15:45:45.242628",
"exception": false,
"start_time": "2022-04-18T15:45:44.796916",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'DEMO-clarify-model-18-04-2022-15-45-44'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_name = \"DEMO-clarify-model-{}\".format(datetime.now().strftime(\"%d-%m-%Y-%H-%M-%S\"))\n",
"model = xgb.create_model(name=model_name)\n",
"container_def = model.prepare_container_def()\n",
"session.create_model(model_name, role, container_def)"
]
},
{
"cell_type": "markdown",
"id": "efc6e484",
"metadata": {
"papermill": {
"duration": 0.035834,
"end_time": "2022-04-18T15:45:45.317431",
"exception": false,
"start_time": "2022-04-18T15:45:45.281597",
"status": "completed"
},
"tags": []
},
"source": [
"## Amazon SageMaker Clarify\n",
"Now that you have your model set up, let's say hello to SageMaker Clarify!"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7d9f5fab",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:45.412724Z",
"iopub.status.busy": "2022-04-18T15:45:45.394186Z",
"iopub.status.idle": "2022-04-18T15:45:45.416063Z",
"shell.execute_reply": "2022-04-18T15:45:45.415616Z"
},
"papermill": {
"duration": 0.062687,
"end_time": "2022-04-18T15:45:45.416177",
"exception": false,
"start_time": "2022-04-18T15:45:45.353490",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from sagemaker import clarify\n",
"\n",
"clarify_processor = clarify.SageMakerClarifyProcessor(\n",
" role=role, instance_count=1, instance_type=\"ml.m5.xlarge\", sagemaker_session=session\n",
")"
]
},
{
"cell_type": "markdown",
"id": "dce26f7f",
"metadata": {
"papermill": {
"duration": 0.035006,
"end_time": "2022-04-18T15:45:45.488228",
"exception": false,
"start_time": "2022-04-18T15:45:45.453222",
"status": "completed"
},
"tags": []
},
"source": [
"### Detecting Bias\n",
"SageMaker Clarify helps you detect possible pre- and post-training biases using a variety of metrics.\n",
"#### Writing DataConfig and ModelConfig\n",
"A `DataConfig` object communicates some basic information about data I/O to SageMaker Clarify. We specify where to find the input dataset, where to store the output, the target column (`label`), the header names, and the dataset type."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "47bf995f",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:45.570752Z",
"iopub.status.busy": "2022-04-18T15:45:45.569644Z",
"iopub.status.idle": "2022-04-18T15:45:45.572869Z",
"shell.execute_reply": "2022-04-18T15:45:45.572368Z"
},
"papermill": {
"duration": 0.043789,
"end_time": "2022-04-18T15:45:45.572994",
"exception": false,
"start_time": "2022-04-18T15:45:45.529205",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"bias_report_output_path = \"s3://{}/{}/clarify-bias\".format(bucket, prefix)\n",
"bias_data_config = clarify.DataConfig(\n",
" s3_data_input_path=train_uri,\n",
" s3_output_path=bias_report_output_path,\n",
" label=\"Target\",\n",
" headers=training_data.columns.to_list(),\n",
" dataset_type=\"text/csv\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "533631dd",
"metadata": {
"papermill": {
"duration": 0.040195,
"end_time": "2022-04-18T15:45:45.654040",
"exception": false,
"start_time": "2022-04-18T15:45:45.613845",
"status": "completed"
},
"tags": []
},
"source": [
"A `ModelConfig` object communicates information about your trained model. To avoid additional traffic to your production models, SageMaker Clarify sets up and tears down a dedicated endpoint when processing.\n",
"* `instance_type` and `instance_count` specify your preferred instance type and instance count used to run your model on during SageMaker Clarify's processing. The testing dataset is small so a single standard instance is good enough to run this example. If your have a large complex dataset, you may want to use a better instance type to speed up, or add more instances to enable Spark parallelization.\n",
"* `accept_type` denotes the endpoint response payload format, and `content_type` denotes the payload format of request to the endpoint."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a8971b78",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:45.730684Z",
"iopub.status.busy": "2022-04-18T15:45:45.729722Z",
"iopub.status.idle": "2022-04-18T15:45:45.732326Z",
"shell.execute_reply": "2022-04-18T15:45:45.732966Z"
},
"papermill": {
"duration": 0.043473,
"end_time": "2022-04-18T15:45:45.733104",
"exception": false,
"start_time": "2022-04-18T15:45:45.689631",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"model_config = clarify.ModelConfig(\n",
" model_name=model_name,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" instance_count=1,\n",
" accept_type=\"text/csv\",\n",
" content_type=\"text/csv\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "29559ca9",
"metadata": {
"papermill": {
"duration": 0.040697,
"end_time": "2022-04-18T15:45:45.809694",
"exception": false,
"start_time": "2022-04-18T15:45:45.768997",
"status": "completed"
},
"tags": []
},
"source": [
"A `ModelPredictedLabelConfig` provides information on the format of your predictions. XGBoost model outputs probabilities of samples, so SageMaker Clarify invokes the endpoint then uses `probability_threshold` to convert the probability to binary labels for bias analysis. Prediction above the threshold is interpreted as label value `1` and below or equal as label value `0`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1b8c33af",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:45.887354Z",
"iopub.status.busy": "2022-04-18T15:45:45.886505Z",
"iopub.status.idle": "2022-04-18T15:45:45.888916Z",
"shell.execute_reply": "2022-04-18T15:45:45.888503Z"
},
"papermill": {
"duration": 0.040649,
"end_time": "2022-04-18T15:45:45.889059",
"exception": false,
"start_time": "2022-04-18T15:45:45.848410",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"predictions_config = clarify.ModelPredictedLabelConfig(probability_threshold=0.8)"
]
},
{
"cell_type": "markdown",
"id": "23672b46",
"metadata": {
"papermill": {
"duration": 0.034177,
"end_time": "2022-04-18T15:45:45.957127",
"exception": false,
"start_time": "2022-04-18T15:45:45.922950",
"status": "completed"
},
"tags": []
},
"source": [
"#### Writing BiasConfig\n",
"SageMaker Clarify also needs information on what the sensitive columns (`facets`) are, what the sensitive features (`facet_values_or_threshold`) may be, and what the desirable outcomes are (`label_values_or_threshold`).\n",
"SageMaker Clarify can handle both categorical and continuous data for `facet_values_or_threshold` and for `label_values_or_threshold`. In this case we are using categorical data.\n",
"\n",
"We specify this information in the `BiasConfig` API. Here that the positive outcome is earning >$50,000, Sex is a sensitive category, and Female respondents are the sensitive group. `group_name` is used to form subgroups for the measurement of Conditional Demographic Disparity in Labels (CDDL) and Conditional Demographic Disparity in Predicted Labels (CDDPL) with regards to Simpson\u2019s paradox."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "95207e36",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:46.030157Z",
"iopub.status.busy": "2022-04-18T15:45:46.029301Z",
"iopub.status.idle": "2022-04-18T15:45:46.032167Z",
"shell.execute_reply": "2022-04-18T15:45:46.031691Z"
},
"papermill": {
"duration": 0.04116,
"end_time": "2022-04-18T15:45:46.032286",
"exception": false,
"start_time": "2022-04-18T15:45:45.991126",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"bias_config = clarify.BiasConfig(\n",
" label_values_or_threshold=[1], facet_name=\"Sex\", facet_values_or_threshold=[0], group_name=\"Age\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c5e5b7be",
"metadata": {
"papermill": {
"duration": 0.034256,
"end_time": "2022-04-18T15:45:46.100764",
"exception": false,
"start_time": "2022-04-18T15:45:46.066508",
"status": "completed"
},
"tags": []
},
"source": [
"#### Pre-training Bias\n",
"Bias can be present in your data before any model training occurs. Inspecting your data for bias before training begins can help detect any data collection gaps, inform your feature engineering, and help you understand what societal biases the data may reflect.\n",
"\n",
"Computing pre-training bias metrics does not require a trained model.\n",
"\n",
"#### Post-training Bias\n",
"Computing post-training bias metrics does require a trained model.\n",
"\n",
"Unbiased training data (as determined by concepts of fairness measured by bias metric) may still result in biased model predictions after training. Whether this occurs depends on several factors including hyperparameter choices.\n",
"\n",
"\n",
"You can run these options separately with `run_pre_training_bias()` and `run_post_training_bias()` or at the same time with `run_bias()` as shown below."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f7001f07",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:45:46.174561Z",
"iopub.status.busy": "2022-04-18T15:45:46.173716Z",
"iopub.status.idle": "2022-04-18T15:54:02.714315Z",
"shell.execute_reply": "2022-04-18T15:54:02.714703Z"
},
"papermill": {
"duration": 496.580275,
"end_time": "2022-04-18T15:54:02.714839",
"exception": false,
"start_time": "2022-04-18T15:45:46.134564",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Job Name: Clarify-Bias-2022-04-18-15-45-46-171\n",
"Inputs: [{'InputName': 'dataset', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/train_data.csv', 'LocalPath': '/opt/ml/processing/input/data', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}, {'InputName': 'analysis_config', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-bias/analysis_config.json', 'LocalPath': '/opt/ml/processing/input/config', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}]\n",
"Outputs: [{'OutputName': 'analysis_result', 'AppManaged': False, 'S3Output': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-bias', 'LocalPath': '/opt/ml/processing/output', 'S3UploadMode': 'EndOfJob'}}]\n",
".............................\u001b[34m2022-04-18 15:50:24,047 logging.conf not found when configuring logging, using default logging configuration.\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,047 Starting SageMaker Clarify Processing job\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,048 Analysis config path: /opt/ml/processing/input/config/analysis_config.json\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,048 Analysis result path: /opt/ml/processing/output\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,048 This host is algo-1.\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,048 This host is the leader.\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,048 Number of hosts in the cluster is 1.\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,194 Running Python / Pandas based analyzer.\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,195 Dataset type: text/csv uri: /opt/ml/processing/input/data\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,205 Loading dataset...\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,239 Loaded dataset. Dataset info:\u001b[0m\n",
"\u001b[34m\u001b[0m\n",
"\u001b[34mRangeIndex: 30162 entries, 0 to 30161\u001b[0m\n",
"\u001b[34mData columns (total 14 columns):\n",
" # Column Non-Null Count Dtype\u001b[0m\n",
"\u001b[34m--- ------ -------------- -----\n",
" 0 Age 30162 non-null int64\n",
" 1 Workclass 30162 non-null int64\n",
" 2 fnlwgt 30162 non-null int64\n",
" 3 Education 30162 non-null int64\n",
" 4 Education-Num 30162 non-null int64\n",
" 5 Marital Status 30162 non-null int64\n",
" 6 Occupation 30162 non-null int64\n",
" 7 Relationship 30162 non-null int64\n",
" 8 Ethnic group 30162 non-null int64\n",
" 9 Sex 30162 non-null int64\n",
" 10 Capital Gain 30162 non-null int64\n",
" 11 Capital Loss 30162 non-null int64\n",
" 12 Hours per week 30162 non-null int64\n",
" 13 Country 30162 non-null int64\u001b[0m\n",
"\u001b[34mdtypes: int64(14)\u001b[0m\n",
"\u001b[34mmemory usage: 3.2 MB\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,377 Spinning up shadow endpoint\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,377 Creating endpoint-config with name sm-clarify-config-1650297024-8acd\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,486 Creating endpoint: 'sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297024-12c3'\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,766 Using endpoint name: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297024-12c3\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,767 Waiting for endpoint ...\u001b[0m\n",
"\u001b[34m2022-04-18 15:50:24,767 Checking endpoint status:\u001b[0m\n",
"\u001b[34mLegend:\u001b[0m\n",
"\u001b[34m(OutOfService: x, Creating: -, Updating: -, InService: !, RollingBack: <, Deleting: o, Failed: *)\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 Endpoint is in service after 180 seconds\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 Endpoint ready.\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 ======================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 Calculating post-training bias metrics\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 ======================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:25,164 Getting predictions from the endpoint\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:27,354 We assume a prediction above 0.800 indicates 1 and below or equal indicates 0.\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:27,355 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:27,357 Column Sex with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:27,360 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:27,362 Column None with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,676 Stop using endpoint: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297024-12c3\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,676 Deleting endpoint configuration with name: sm-clarify-config-1650297024-8acd\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,745 Deleting endpoint with name: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297024-12c3\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,834 Model endpoint delivered 0.56949 requests per second and a total of 2 requests over 4 seconds\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,835 Stop using endpoint: None\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,835 =====================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,835 Calculating pre-training bias metrics\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,835 =====================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,835 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,838 Column Sex with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:28,840 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,145 ======================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,145 Calculating bias statistics for report\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,145 ======================================\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,146 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,149 Column Sex with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,151 Column Target with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,153 Column None with data uniqueness fraction 6.630860022544923e-05 is classifed as a CATEGORICAL column\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,161 Stop using endpoint: None\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:29,430 jupyter nbconvert --to html --output /opt/ml/processing/output/report.html /opt/ml/processing/output/report.ipynb --template sagemaker-xai\u001b[0m\n",
"\u001b[34m[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html\u001b[0m\n",
"\u001b[34m[NbConvertApp] Writing 344696 bytes to /opt/ml/processing/output/report.html\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:30,320 HTML report '/opt/ml/processing/output/report.html' generated successfully.\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:30,320 wkhtmltopdf -q /opt/ml/processing/output/report.html /opt/ml/processing/output/report.pdf\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:30,813 PDF report '/opt/ml/processing/output/report.pdf' generated successfully.\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:30,813 Collected analyses: \u001b[0m\n",
"\u001b[34m{\n",
" \"version\": \"1.0\",\n",
" \"post_training_bias_metrics\": {\n",
" \"label\": \"Target\",\n",
" \"facets\": {\n",
" \"Sex\": [\n",
" {\n",
" \"value_or_threshold\": \"0\",\n",
" \"metrics\": [\n",
" {\n",
" \"name\": \"AD\",\n",
" \"description\": \"Accuracy Difference (AD)\",\n",
" \"value\": -0.1141572442143538\n",
" },\n",
" {\n",
" \"name\": \"CDDPL\",\n",
" \"description\": \"Conditional Demographic Disparity in Predicted Labels (CDDPL)\",\n",
" \"value\": 0.19716203919079375\n",
" },\n",
" {\n",
" \"name\": \"DAR\",\n",
" \"description\": \"Difference in Acceptance Rates (DAR)\",\n",
" \"value\": -0.007405223292617502\n",
" },\n",
" {\n",
" \"name\": \"DCA\",\n",
" \"description\": \"Difference in Conditional Acceptance (DCA)\",\n",
" \"value\": -0.22750276729134145\n",
" },\n",
" {\n",
" \"name\": \"DCR\",\n",
" \"description\": \"Difference in Conditional Rejection (DCR)\",\n",
" \"value\": 0.13282504190308553\n",
" },\n",
" {\n",
" \"name\": \"DI\",\n",
" \"description\": \"Disparate Impact (DI)\",\n",
" \"value\": 0.32939129409419415\n",
" },\n",
" {\n",
" \"name\": \"DPPL\",\n",
" \"description\": \"Difference in Positive Proportions in Predicted Labels (DPPL)\",\n",
" \"value\": 0.0922004707530946\n",
" },\n",
" {\n",
" \"name\": \"DRR\",\n",
" \"description\": \"Difference in Rejection Rates (DRR)\",\n",
" \"value\": 0.13653296409568605\n",
" },\n",
" {\n",
" \"name\": \"FT\",\n",
" \"description\": \"Flip Test (FT)\",\n",
" \"value\": -0.004600286240032713\n",
" },\n",
" {\n",
" \"name\": \"RD\",\n",
" \"description\": \"Recall Difference (RD)\",\n",
" \"value\": 0.03556460647616988\n",
" },\n",
" {\n",
" \"name\": \"TE\",\n",
" \"description\": \"Treatment Equality (TE)\",\n",
" \"value\": 24.11428571428572\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
" },\n",
" \"label_value_or_threshold\": \"1\"\n",
" },\n",
" \"pre_training_bias_metrics\": {\n",
" \"label\": \"Target\",\n",
" \"facets\": {\n",
" \"Sex\": [\n",
" {\n",
" \"value_or_threshold\": \"0\",\n",
" \"metrics\": [\n",
" {\n",
" \"name\": \"CDDL\",\n",
" \"description\": \"Conditional Demographic Disparity in Labels (CDDL)\",\n",
" \"value\": 0.214915908649356\n",
" },\n",
" {\n",
" \"name\": \"CI\",\n",
" \"description\": \"Class Imbalance (CI)\",\n",
" \"value\": 0.3513692725946555\n",
" },\n",
" {\n",
" \"name\": \"DPL\",\n",
" \"description\": \"Difference in Positive Proportions in Labels (DPL)\",\n",
" \"value\": 0.20015891077100018\n",
" },\n",
" {\n",
" \"name\": \"JS\",\n",
" \"description\": \"Jensen-Shannon Divergence (JS)\",\n",
" \"value\": 0.030756144659773006\n",
" },\n",
" {\n",
" \"name\": \"KL\",\n",
" \"description\": \"Kullback-Liebler Divergence (KL)\",\n",
" \"value\": 0.14306865156306434\n",
" },\n",
" {\n",
" \"name\": \"KS\",\n",
" \"description\": \"Kolmogorov-Smirnov Distance (KS)\",\n",
" \"value\": 0.20015891077100018\n",
" },\n",
" {\n",
" \"name\": \"LP\",\n",
" \"description\": \"L-p Norm (LP)\",\n",
" \"value\": 0.2830674462421746\n",
" },\n",
" {\n",
" \"name\": \"TVD\",\n",
" \"description\": \"Total Variation Distance (TVD)\",\n",
" \"value\": 0.20015891077100015\n",
" }\n",
" ]\n",
" }\n",
" ]\n",
" },\n",
" \"label_value_or_threshold\": \"1\"\n",
" }\u001b[0m\n",
"\u001b[34m}\u001b[0m\n",
"\u001b[34m2022-04-18 15:53:30,814 exit_message: Completed: SageMaker XAI Analyzer ran successfully\u001b[0m\n",
"\u001b[34m---!\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"clarify_processor.run_bias(\n",
" data_config=bias_data_config,\n",
" bias_config=bias_config,\n",
" model_config=model_config,\n",
" model_predicted_label_config=predictions_config,\n",
" pre_training_methods=\"all\",\n",
" post_training_methods=\"all\",\n",
")"
]
},
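{
"cell_type": "markdown",
"id": "b1a5c0de",
"metadata": {},
"source": [
"Two of the pre-training metrics reported in the log above have simple closed forms: Class Imbalance is `CI = (n_a - n_d) / (n_a + n_d)` and Difference in Positive Proportions in Labels is `DPL = q_a - q_d`, where `d` denotes the sensitive facet and `a` the remaining rows. The following is a minimal sketch on made-up records, not Clarify's implementation; the `records` list and the helper names are hypothetical.\n",
"\n",
"```python\n",
"# Hypothetical toy records as (sex, target) pairs. As in the BiasConfig above,\n",
"# sex == 0 marks the sensitive facet and target == 1 the positive outcome.\n",
"records = [(0, 0), (0, 0), (0, 1), (1, 1), (1, 1), (1, 1), (1, 0), (1, 0), (1, 1), (1, 0)]\n",
"\n",
"\n",
"def class_imbalance(rows, sensitive_value=0):\n",
"    # CI = (n_a - n_d) / (n_a + n_d): how under-represented the sensitive facet is.\n",
"    n_d = sum(1 for sex, _ in rows if sex == sensitive_value)\n",
"    n_a = len(rows) - n_d\n",
"    return (n_a - n_d) / (n_a + n_d)\n",
"\n",
"\n",
"def positive_rate(rows, in_sensitive_group, sensitive_value=0, positive_value=1):\n",
"    # Share of positive labels inside (or outside) the sensitive facet.\n",
"    labels = [t for s, t in rows if (s == sensitive_value) == in_sensitive_group]\n",
"    return sum(1 for t in labels if t == positive_value) / len(labels)\n",
"\n",
"\n",
"def dpl(rows):\n",
"    # DPL = q_a - q_d: gap in positive-label proportions between the two facets.\n",
"    return positive_rate(rows, False) - positive_rate(rows, True)\n",
"\n",
"\n",
"ci_value = class_imbalance(records)\n",
"dpl_value = dpl(records)\n",
"print(f'CI = {ci_value:.3f}, DPL = {dpl_value:.3f}')\n",
"```\n",
"\n",
"Both metrics are 0 when the facets are balanced; positive values indicate that the sensitive facet is under-represented (CI) or receives the positive label less often (DPL)."
]
},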
{
"cell_type": "markdown",
"id": "f7cda707",
"metadata": {
"papermill": {
"duration": 0.048381,
"end_time": "2022-04-18T15:54:02.817876",
"exception": false,
"start_time": "2022-04-18T15:54:02.769495",
"status": "completed"
},
"tags": []
},
"source": [
"#### Viewing the Bias Report\n",
"In Studio, you can view the results under the experiments tab.\n",
"\n",
" \n",
"\n",
"Each bias metric has detailed explanations with examples that you can explore.\n",
"\n",
" \n",
"\n",
"You could also summarize the results in a handy table!\n",
"\n",
" \n"
]
},
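{
"cell_type": "markdown",
"id": "b1a5ab1e",
"metadata": {},
"source": [
"One way to build such a table yourself is to flatten the JSON result of the bias job. The sketch below assumes the result has the same shape as the `Collected analyses` block in the job log above and uses a small excerpt of those values; `flatten_metrics` is a hypothetical helper, not part of the SageMaker SDK.\n",
"\n",
"```python\n",
"# Excerpt of the 'Collected analyses' output logged by the bias job above.\n",
"analysis = {\n",
"    'pre_training_bias_metrics': {\n",
"        'facets': {'Sex': [{'value_or_threshold': '0', 'metrics': [\n",
"            {'name': 'CI', 'value': 0.3513692725946555},\n",
"            {'name': 'DPL', 'value': 0.20015891077100018},\n",
"        ]}]}\n",
"    },\n",
"    'post_training_bias_metrics': {\n",
"        'facets': {'Sex': [{'value_or_threshold': '0', 'metrics': [\n",
"            {'name': 'DI', 'value': 0.32939129409419415},\n",
"            {'name': 'DPPL', 'value': 0.0922004707530946},\n",
"        ]}]}\n",
"    },\n",
"}\n",
"\n",
"\n",
"def flatten_metrics(result):\n",
"    # Turn the nested stage -> facet -> metrics structure into flat rows.\n",
"    rows = []\n",
"    for stage in ('pre_training_bias_metrics', 'post_training_bias_metrics'):\n",
"        facets = result.get(stage, {}).get('facets', {})\n",
"        for facet, entries in facets.items():\n",
"            for entry in entries:\n",
"                for metric in entry['metrics']:\n",
"                    rows.append(\n",
"                        (stage, facet, entry['value_or_threshold'], metric['name'], metric['value'])\n",
"                    )\n",
"    return rows\n",
"\n",
"\n",
"rows = flatten_metrics(analysis)\n",
"for stage, facet, facet_value, name, metric_value in rows:\n",
"    print(f'{stage:<28} {facet:<4} {facet_value:<2} {name:<5} {metric_value: .4f}')\n",
"```\n",
"\n",
"The same rows can be passed to `pandas.DataFrame` if you prefer a rendered table in the notebook."
]
},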
{
"cell_type": "markdown",
"id": "3caf52ba",
"metadata": {
"papermill": {
"duration": 0.048103,
"end_time": "2022-04-18T15:54:02.915109",
"exception": false,
"start_time": "2022-04-18T15:54:02.867006",
"status": "completed"
},
"tags": []
},
"source": [
"If you're not a Studio user yet, you can access the bias report in pdf, html and ipynb formats in the following S3 bucket:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "9847314b",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:54:03.117134Z",
"iopub.status.busy": "2022-04-18T15:54:03.110947Z",
"iopub.status.idle": "2022-04-18T15:54:03.124304Z",
"shell.execute_reply": "2022-04-18T15:54:03.121910Z"
},
"papermill": {
"duration": 0.128454,
"end_time": "2022-04-18T15:54:03.124461",
"exception": false,
"start_time": "2022-04-18T15:54:02.996007",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'s3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-bias'"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bias_report_output_path"
]
},
{
"cell_type": "markdown",
"id": "cd6fa1c1",
"metadata": {
"papermill": {
"duration": 0.0441,
"end_time": "2022-04-18T15:54:03.260887",
"exception": false,
"start_time": "2022-04-18T15:54:03.216787",
"status": "completed"
},
"tags": []
},
"source": [
"### Explaining Predictions\n",
"There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision."
]
},
{
"cell_type": "markdown",
"id": "3596d952",
"metadata": {
"papermill": {
"duration": 0.042869,
"end_time": "2022-04-18T15:54:03.346879",
"exception": false,
"start_time": "2022-04-18T15:54:03.304010",
"status": "completed"
},
"tags": []
},
"source": [
"Kernel SHAP algorithm requires a baseline (also known as background dataset). If not provided, a baseline is calculated automatically by SageMaker Clarify using K-means or K-prototypes in the input dataset. Baseline dataset type shall be the same as `dataset_type` of `DataConfig`, and baseline samples shall only include features. By definition, `baseline` should either be a S3 URI to the baseline dataset file, or an in-place list of samples. In this case we chose the latter, and put the first sample of the test dataset to the list."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "8efafe15",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:54:03.444618Z",
"iopub.status.busy": "2022-04-18T15:54:03.443803Z",
"iopub.status.idle": "2022-04-18T15:54:03.446753Z",
"shell.execute_reply": "2022-04-18T15:54:03.446309Z"
},
"papermill": {
"duration": 0.052082,
"end_time": "2022-04-18T15:54:03.446918",
"exception": false,
"start_time": "2022-04-18T15:54:03.394836",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"shap_config = clarify.SHAPConfig(\n",
" baseline=[test_features.iloc[0].values.tolist()],\n",
" num_samples=15,\n",
" agg_method=\"mean_abs\",\n",
" save_local_shap_values=True,\n",
")\n",
"\n",
"explainability_output_path = \"s3://{}/{}/clarify-explainability\".format(bucket, prefix)\n",
"explainability_data_config = clarify.DataConfig(\n",
" s3_data_input_path=train_uri,\n",
" s3_output_path=explainability_output_path,\n",
" label=\"Target\",\n",
" headers=training_data.columns.to_list(),\n",
" dataset_type=\"text/csv\",\n",
")"
]
},
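{
"cell_type": "markdown",
"id": "5ba91e70",
"metadata": {},
"source": [
"To see what role the baseline plays, the sketch below computes exact Shapley values for a tiny hypothetical linear model by enumerating every feature coalition; features outside a coalition take their values from the baseline. This is not Kernel SHAP itself, which approximates the same quantity from a limited number of sampled coalitions (`num_samples`), and `toy_model`, `x`, and `baseline` are made up for illustration.\n",
"\n",
"```python\n",
"from itertools import combinations\n",
"from math import factorial\n",
"\n",
"\n",
"def shapley_values(model, x, baseline):\n",
"    # Exact Shapley values of model(x) relative to model(baseline).\n",
"    n = len(x)\n",
"\n",
"    def value(subset):\n",
"        # Features in the coalition come from x, all others from the baseline.\n",
"        mixed = [x[i] if i in subset else baseline[i] for i in range(n)]\n",
"        return model(mixed)\n",
"\n",
"    phi = []\n",
"    for i in range(n):\n",
"        others = [j for j in range(n) if j != i]\n",
"        total = 0.0\n",
"        for size in range(n):\n",
"            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.\n",
"            weight = factorial(size) * factorial(n - size - 1) / factorial(n)\n",
"            for subset in combinations(others, size):\n",
"                coalition = set(subset)\n",
"                total += weight * (value(coalition | {i}) - value(coalition))\n",
"        phi.append(total)\n",
"    return phi\n",
"\n",
"\n",
"def toy_model(features):\n",
"    # Hypothetical linear scoring function with three features.\n",
"    return 0.5 * features[0] + 2.0 * features[1] - 1.0 * features[2]\n",
"\n",
"\n",
"x = [4.0, 1.0, 3.0]\n",
"baseline = [1.0, 0.0, 1.0]\n",
"phi = shapley_values(toy_model, x, baseline)\n",
"print(phi)\n",
"```\n",
"\n",
"For a linear model each attribution equals the weight times the feature's distance from the baseline, and the attributions sum to the difference between the prediction for `x` and the prediction for the baseline. A different baseline therefore yields different attributions for the same sample."
]
},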
{
"cell_type": "code",
"execution_count": 20,
"id": "2a001d69",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T15:54:03.545718Z",
"iopub.status.busy": "2022-04-18T15:54:03.537884Z",
"iopub.status.idle": "2022-04-18T16:07:30.752399Z",
"shell.execute_reply": "2022-04-18T16:07:30.750754Z"
},
"papermill": {
"duration": 807.262471,
"end_time": "2022-04-18T16:07:30.752546",
"exception": false,
"start_time": "2022-04-18T15:54:03.490075",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Job Name: Clarify-Explainability-2022-04-18-15-54-03-536\n",
"Inputs: [{'InputName': 'dataset', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/train_data.csv', 'LocalPath': '/opt/ml/processing/input/data', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}, {'InputName': 'analysis_config', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-explainability/analysis_config.json', 'LocalPath': '/opt/ml/processing/input/config', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}]\n",
"Outputs: [{'OutputName': 'analysis_result', 'AppManaged': False, 'S3Output': {'S3Uri': 's3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-explainability', 'LocalPath': '/opt/ml/processing/output', 'S3UploadMode': 'EndOfJob'}}]\n",
".............................\u001b[34m2022-04-18 15:58:37,472 logging.conf not found when configuring logging, using default logging configuration.\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 Starting SageMaker Clarify Processing job\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 Analysis config path: /opt/ml/processing/input/config/analysis_config.json\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 Analysis result path: /opt/ml/processing/output\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 This host is algo-1.\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 This host is the leader.\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,473 Number of hosts in the cluster is 1.\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,640 Running Python / Pandas based analyzer.\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,640 Dataset type: text/csv uri: /opt/ml/processing/input/data\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,651 Loading dataset...\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,687 Loaded dataset. Dataset info:\u001b[0m\n",
"\u001b[34m\u001b[0m\n",
"\u001b[34mRangeIndex: 30162 entries, 0 to 30161\u001b[0m\n",
"\u001b[34mData columns (total 14 columns):\n",
" # Column Non-Null Count Dtype\u001b[0m\n",
"\u001b[34m--- ------ -------------- -----\n",
" 0 Age 30162 non-null int64\n",
" 1 Workclass 30162 non-null int64\n",
" 2 fnlwgt 30162 non-null int64\n",
" 3 Education 30162 non-null int64\n",
" 4 Education-Num 30162 non-null int64\n",
" 5 Marital Status 30162 non-null int64\n",
" 6 Occupation 30162 non-null int64\n",
" 7 Relationship 30162 non-null int64\n",
" 8 Ethnic group 30162 non-null int64\n",
" 9 Sex 30162 non-null int64\n",
" 10 Capital Gain 30162 non-null int64\n",
" 11 Capital Loss 30162 non-null int64\n",
" 12 Hours per week 30162 non-null int64\n",
" 13 Country 30162 non-null int64\u001b[0m\n",
"\u001b[34mdtypes: int64(14)\u001b[0m\n",
"\u001b[34mmemory usage: 3.2 MB\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,835 Spinning up shadow endpoint\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,835 Creating endpoint-config with name sm-clarify-config-1650297517-31f2\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:37,915 Creating endpoint: 'sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297517-f947'\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:38,339 Using endpoint name: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297517-f947\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:38,339 Waiting for endpoint ...\u001b[0m\n",
"\u001b[34m2022-04-18 15:58:38,339 Checking endpoint status:\u001b[0m\n",
"\u001b[34mLegend:\u001b[0m\n",
"\u001b[34m(OutOfService: x, Creating: -, Updating: -, InService: !, RollingBack: <, Deleting: o, Failed: *)\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,784 Endpoint is in service after 180 seconds\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,785 Endpoint ready.\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,786 SHAP n_samples 15\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,904 =====================================================\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,904 Shap analyzer: explaining 30162 rows, 14 columns...\u001b[0m\n",
"\u001b[34m2022-04-18 16:01:38,904 =====================================================\u001b[0m\n",
"\u001b[34m 0% (0 of 30162) | | Elapsed Time: 0:00:00 ETA: --:--:--\u001b[0m\n",
"\u001b[34m 9% (2798 of 30162) |# | Elapsed Time: 0:00:30 ETA: 0:04:53\u001b[0m\n",
"\u001b[34m 18% (5681 of 30162) |### | Elapsed Time: 0:01:00 ETA: 0:04:14\u001b[0m\n",
"\u001b[34m 28% (8703 of 30162) |##### | Elapsed Time: 0:01:30 ETA: 0:03:33\u001b[0m\n",
"\u001b[34m 39% (11989 of 30162) |####### | Elapsed Time: 0:02:00 ETA: 0:02:45\u001b[0m\n",
"\u001b[34m 51% (15453 of 30162) |######### | Elapsed Time: 0:02:30 ETA: 0:02:07\u001b[0m\n",
"\u001b[34m 62% (18999 of 30162) |########### | Elapsed Time: 0:03:00 ETA: 0:01:34\u001b[0m\n",
"\u001b[34m 74% (22510 of 30162) |############# | Elapsed Time: 0:03:30 ETA: 0:01:05\u001b[0m\n",
"\u001b[34m 86% (25999 of 30162) |############### | Elapsed Time: 0:04:00 ETA: 0:00:35\u001b[0m\n",
"\u001b[34m 97% (29442 of 30162) |################# | Elapsed Time: 0:04:30 ETA: 0:00:06\u001b[0m\n",
"\u001b[34m100% (30162 of 30162) |##################| Elapsed Time: 0:04:36 Time: 0:04:36\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:15,695 getting explanations took 276.79 seconds.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:15,695 ===================================================\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,084 converting explanations to tabular took 2.39 seconds.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,084 ===================================================\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,088 Wrote baseline used to compute explanations to: /opt/ml/processing/output/explanations_shap/baseline.csv\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,731 Wrote 30162 local explanations to: /opt/ml/processing/output/explanations_shap/out.csv\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,731 writing local explanations took 0.65 seconds.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,731 ===================================================\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,735 aggregating local explanations took 0.00 seconds.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,735 ===================================================\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,736 Shap analysis finished.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,736 Stop using endpoint: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297517-f947\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,736 Deleting endpoint configuration with name: sm-clarify-config-1650297517-31f2\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,853 Deleting endpoint with name: sm-clarify-DEMO-clarify-model-18-04-2022-15-45--1650297517-f947\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:18,997 Model endpoint delivered 107.74733 requests per second and a total of 30164 requests over 280 seconds\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:26,813 Stop using endpoint: None\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:50,264 jupyter nbconvert --to html --output /opt/ml/processing/output/report.html /opt/ml/processing/output/report.ipynb --template sagemaker-xai\u001b[0m\n",
"\u001b[34m[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html\u001b[0m\n",
"\u001b[34m[NbConvertApp] Writing 570437 bytes to /opt/ml/processing/output/report.html\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:51,317 HTML report '/opt/ml/processing/output/report.html' generated successfully.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:51,317 wkhtmltopdf -q /opt/ml/processing/output/report.html /opt/ml/processing/output/report.pdf\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:51,874 PDF report '/opt/ml/processing/output/report.pdf' generated successfully.\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:51,875 Collected analyses: \u001b[0m\n",
"\u001b[34m{\n",
" \"version\": \"1.0\",\n",
" \"explanations\": {\n",
" \"kernel_shap\": {\n",
" \"label0\": {\n",
" \"global_shap_values\": {\n",
" \"Age\": 0.03655626472022009,\n",
" \"Workclass\": 0.017905832546722414,\n",
" \"fnlwgt\": 0.021385894167534045,\n",
" \"Education\": 0.018535316056790388,\n",
" \"Education-Num\": 0.03609330944536093,\n",
" \"Marital Status\": 0.02898470399620499,\n",
" \"Occupation\": 0.026482349644306062,\n",
" \"Relationship\": 0.03615980532944972,\n",
" \"Ethnic group\": 0.020033663785676746,\n",
" \"Sex\": 0.017880631469685705,\n",
" \"Capital Gain\": 0.033581907850084025,\n",
" \"Capital Loss\": 0.019556674983842386,\n",
" \"Hours per week\": 0.021283579738450336,\n",
" \"Country\": 0.04712774225543154\n",
" },\n",
" \"expected_value\": 0.0006380207487381995\n",
" }\n",
" }\n",
" }\u001b[0m\n",
"\u001b[34m}\u001b[0m\n",
"\u001b[34m2022-04-18 16:06:51,875 exit_message: Completed: SageMaker XAI Analyzer ran successfully\u001b[0m\n",
"\u001b[34m---!\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"clarify_processor.run_explainability(\n",
" data_config=explainability_data_config,\n",
" model_config=model_config,\n",
" explainability_config=shap_config,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "33d4727d",
"metadata": {
"papermill": {
"duration": 0.066023,
"end_time": "2022-04-18T16:07:30.893004",
"exception": false,
"start_time": "2022-04-18T16:07:30.826981",
"status": "completed"
},
"tags": []
},
"source": [
"#### Viewing the Explainability Report\n",
"As with the bias report, you can view the explainability report in Studio under the experiments tab.\n",
"\n",
"\n",
" \n",
"\n",
"The Model Insights tab contains direct links to the report and model insights.\n",
"\n",
"If you're not a Studio user yet, as with the Bias Report, you can access this report at the following S3 bucket."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "5e571398",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:31.028316Z",
"iopub.status.busy": "2022-04-18T16:07:31.027352Z",
"iopub.status.idle": "2022-04-18T16:07:31.030696Z",
"shell.execute_reply": "2022-04-18T16:07:31.031231Z"
},
"papermill": {
"duration": 0.077024,
"end_time": "2022-04-18T16:07:31.031371",
"exception": false,
"start_time": "2022-04-18T16:07:30.954347",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'s3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-sagemaker-clarify/clarify-explainability'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explainability_output_path"
]
},
{
"cell_type": "markdown",
"id": "57a8f107",
"metadata": {
"papermill": {
"duration": 0.061215,
"end_time": "2022-04-18T16:07:31.154052",
"exception": false,
"start_time": "2022-04-18T16:07:31.092837",
"status": "completed"
},
"tags": []
},
"source": [
"#### Analysis of local explanations\n",
"It is possible to visualize the the local explanations for single examples in your dataset. You can use the obtained results from running Kernel SHAP algorithm for global explanations.\n",
"\n",
"You can simply load the local explanations stored in your output path, and visualize the explanation (i.e., the impact that the single features have on the prediction of your model) for any single example."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "90a09a11",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:31.293968Z",
"iopub.status.busy": "2022-04-18T16:07:31.292566Z",
"iopub.status.idle": "2022-04-18T16:07:31.821690Z",
"shell.execute_reply": "2022-04-18T16:07:31.821217Z"
},
"papermill": {
"duration": 0.598354,
"end_time": "2022-04-18T16:07:31.821806",
"exception": false,
"start_time": "2022-04-18T16:07:31.223452",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example number: 111 \n",
"with model prediction: False\n",
"\n",
"Feature values -- Label Target 0\n",
"Age 21\n",
"Workclass 2\n",
"fnlwgt 199915\n",
"Education 15\n",
"Education-Num 10\n",
"Marital Status 4\n",
"Occupation 7\n",
"Relationship 3\n",
"Ethnic group 4\n",
"Sex 0\n",
"Capital Gain 0\n",
"Capital Loss 0\n",
"Hours per week 40\n",
"Country 38\n",
"Name: 120, dtype: int64\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"local_explanations_out = pd.read_csv(explainability_output_path + \"/explanations_shap/out.csv\")\n",
"feature_names = [str.replace(c, \"_label0\", \"\") for c in local_explanations_out.columns.to_series()]\n",
"local_explanations_out.columns = feature_names\n",
"\n",
"selected_example = 111\n",
"print(\n",
" \"Example number:\",\n",
" selected_example,\n",
" \"\\nwith model prediction:\",\n",
" sum(local_explanations_out.iloc[selected_example]) > 0,\n",
")\n",
"print(\"\\nFeature values -- Label\", training_data.iloc[selected_example])\n",
"local_explanations_out.iloc[selected_example].plot(\n",
" kind=\"bar\", title=\"Local explanation for the example number \" + str(selected_example), rot=90\n",
")"
]
},
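{
"cell_type": "markdown",
"id": "a99c0de5",
"metadata": {},
"source": [
"The local values also explain where the global numbers in the report come from: with `agg_method` set to `mean_abs`, the global SHAP value of a feature is the mean of the absolute local SHAP values over all explained rows. The following is a minimal sketch on made-up numbers; `local_shap` and the `mean_abs` helper are hypothetical and do not read the job output.\n",
"\n",
"```python\n",
"# Made-up local SHAP values: one dict per explained row, one entry per feature.\n",
"local_shap = [\n",
"    {'Age': 0.04, 'Sex': -0.02, 'Capital Gain': 0.10},\n",
"    {'Age': -0.06, 'Sex': 0.01, 'Capital Gain': 0.00},\n",
"    {'Age': 0.02, 'Sex': -0.03, 'Capital Gain': -0.05},\n",
"]\n",
"\n",
"\n",
"def mean_abs(rows):\n",
"    # Average the absolute attribution of each feature across all rows.\n",
"    features = list(rows[0].keys())\n",
"    return {f: sum(abs(r[f]) for r in rows) / len(rows) for f in features}\n",
"\n",
"\n",
"global_shap = mean_abs(local_shap)\n",
"# Rank features from most to least influential overall.\n",
"ranking = sorted(global_shap, key=global_shap.get, reverse=True)\n",
"print(global_shap)\n",
"print(ranking)\n",
"```\n",
"\n",
"Taking absolute values before averaging keeps positive and negative contributions from cancelling out, so a feature that pushes predictions strongly in both directions still ranks as important."
]
},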
{
"cell_type": "markdown",
"id": "5d169fff",
"metadata": {
"papermill": {
"duration": 0.054899,
"end_time": "2022-04-18T16:07:31.931845",
"exception": false,
"start_time": "2022-04-18T16:07:31.876946",
"status": "completed"
},
"tags": []
},
"source": [
"### Clean Up\n",
"Finally, don't forget to clean up the resources we set up and used for this demo!"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "4a66aede",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:32.047815Z",
"iopub.status.busy": "2022-04-18T16:07:32.047054Z",
"iopub.status.idle": "2022-04-18T16:07:32.151107Z",
"shell.execute_reply": "2022-04-18T16:07:32.150588Z"
},
"papermill": {
"duration": 0.164477,
"end_time": "2022-04-18T16:07:32.151256",
"exception": false,
"start_time": "2022-04-18T16:07:31.986779",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"session.delete_model(model_name)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notebook CI Test Results\n",
"\n",
"This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
}
],
"metadata": {
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
},
"papermill": {
"default_parameters": {},
"duration": 1526.483007,
"end_time": "2022-04-18T16:07:32.926299",
"environment_variables": {},
"exception": null,
"input_path": "fairness_and_explainability.ipynb",
"output_path": "/opt/ml/processing/output/fairness_and_explainability-2022-04-18-15-28-21.ipynb",
"parameters": {
"kms_key": "arn:aws:kms:us-west-2:000000000000:1234abcd-12ab-34cd-56ef-1234567890ab"
},
"start_time": "2022-04-18T15:42:06.443292",
"version": "2.3.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}