{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Responsible AI - SageMaker Clarify\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Amazon SageMaker Clarify](https://aws.amazon.com/sagemaker/clarify/) helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models.\n", "\n", "In this notebook, we show how to use SageMaker to train a model, host it on an inference endpoint, and apply bias detection and explainability to analyze the data and understand the model's prediction outcomes.\n", "\n", "\n", "__Dataset:__ \n", "The dataset for this exercise comes from [folktables](https://github.com/zykls/folktables). Folktables provides code to download data from the American Community Survey (ACS) Public Use Microdata Sample (PUMS) files managed by the US Census Bureau. The data itself is governed by the terms of use provided by the Census Bureau. For more information, see the [Terms of Service](https://www.census.gov/data/developers/about/terms-of-service.html).\n", "\n", "__ML Problem:__ \n", "The goal is to predict whether an individual's income is above \\\\$50,000. We filter the ACS PUMS data sample to include only individuals older than 16 who reported usually working at least 1 hour per week in the past year and an income of at least \\\\$100. The \\\\$50,000 threshold was chosen so that this dataset can serve as a comparable substitute for the [UCI Adult dataset](https://archive.ics.uci.edu/ml/datasets/adult). The income threshold can easily be changed to define new prediction tasks.\n", "\n", "\n", "1. Read the dataset\n",
"2. Data Processing\n", "   * Exploratory Data Analysis\n", "   * Select features to build the model\n", "   * Feature Transformation\n", "   * Train - Validation - Test Datasets\n", "   * Data processing with Pipeline and ColumnTransformer\n", "3. Train (and Tune) a Classifier\n", "4. Amazon SageMaker Clarify" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook assumes that the SageMaker kernel `conda_python3` is available. In addition, install folktables and scikit-learn:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install -U -q scikit-learn==1.1.3\n", "!pip install -U -q --no-deps folktables==0.0.11" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Reshaping/basic libraries\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# Plotting libraries\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "sns.set_style(\"darkgrid\", {\"axes.facecolor\": \".9\"})\n", "\n", "# ML libraries\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "# Operational libraries\n", "import sys\n", "from io import StringIO\n", "import os\n", "import time\n", "import IPython\n", "from time import gmtime, strftime\n", "from datetime import datetime, timedelta\n", "\n", "sys.path.append(\"..\")\n", "\n", "# Fairness libraries\n", "from folktables.acs import *\n", "from folktables.folktables import *\n", "from folktables.load_acs import *\n", "\n", "# Jupyter(lab) libraries\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "\n", "# SageMaker and connection libraries\n", 
"import boto3\n", "import urllib\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker import Session\n", "from sagemaker.amazon.amazon_estimator import get_image_uri\n", "from sagemaker.inputs import TrainingInput\n", "from sagemaker.serializers import CSVSerializer\n", "from sagemaker.s3 import S3Downloader, S3Uploader\n", "from sagemaker import clarify\n", "from sagemaker import model_monitor\n", "from datetime import date\n", "\n", "today = date.today()\n", "bucket = sagemaker.Session().default_bucket()\n", "prfx = \"sagemaker/sagemaker-clarify-income-model\"\n", "region = boto3.Session().region_name\n", "client = boto3.client('sagemaker')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Read the dataset\n", "\n", "To read in the dataset, we use [folktables](https://github.com/zykls/folktables), which provides access to the US Census ACS PUMS data. Folktables contains predefined prediction tasks but also lets you define your own prediction problem.\n", "\n", "The US Census dataset distinguishes between households and individuals. To obtain data on individuals, we use `ACSDataSource` with `survey=\"person\"`. The feature names follow the same distinction and use `P` for `person` and `H` for `household`; for example, `AGEP` refers to the age of an individual."
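, "\n", "
The row filter and label that the `ACSIncome` problem in the next cell applies can be mimicked in plain pandas. This is an illustration, not the folktables implementation: `filter_and_label` is a hypothetical helper, the toy rows are made up, and the thresholds (`AGEP > 16`, `PINCP > 100`, `WKHP > 0`, `PWGTP >= 1`) are an assumption about what `adult_filter` checks, mirroring the UCI Adult conditions quoted in that cell's comment.

```python
import pandas as pd


def filter_and_label(data, threshold=50000):
    # Hypothetical helper (not part of folktables): the row filter below is
    # an assumed re-implementation of folktables' adult_filter.
    kept = data[
        (data['AGEP'] > 16)
        & (data['PINCP'] > 100)
        & (data['WKHP'] > 0)
        & (data['PWGTP'] >= 1)
    ].copy()
    # Binary label, same rule as target_transform=lambda x: x > 50000
    kept['>50k'] = (kept['PINCP'] > threshold).astype(int)
    # Keep only the two RAC1P groups (6 and 8) used later in the notebook
    return kept[kept['RAC1P'].isin([6, 8])]


# Made-up toy rows; each of rows 1 to 4 fails exactly one condition
toy = pd.DataFrame({
    'AGEP': [30, 15, 45, 52, 28, 40],
    'PINCP': [60000, 20000, 40000, 90000, 50, 30000],
    'WKHP': [40, 20, 0, 50, 10, 35],
    'PWGTP': [32, 10, 25, 40, 12, 20],
    'RAC1P': [8, 6, 6, 1, 8, 6],
})

print(filter_and_label(toy))
```

With these toy rows, rows 1 to 4 are dropped by the age, hours-worked, race-group and minimum-income conditions respectively, while rows 0 and 5 are kept with labels 1 and 0. Combining the conditions into one boolean mask selects the same rows as applying them one after another.
"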
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "income_features = [\n", "    \"AGEP\",  # age of the individual\n", "    \"COW\",  # class of worker\n", "    \"SCHL\",  # educational attainment\n", "    \"MAR\",  # marital status\n", "    \"OCCP\",  # occupation\n", "    \"POBP\",  # place of birth\n", "    \"RELP\",  # relationship\n", "    \"WKHP\",  # usual hours worked per week in the past 12 months\n", "    \"SEX\",  # sex\n", "    \"RAC1P\",  # recoded detailed race code\n", "    \"PWGTP\",  # person's weight\n", "    \"GCL\",  # grandparents living with grandchildren\n", "    \"SCH\",  # school enrollment\n", "]\n", "\n", "# Define the prediction problem and features\n", "ACSIncome = folktables.BasicProblem(\n", "    features=income_features,\n", "    target=\"PINCP\",  # total person's income\n", "    target_transform=lambda x: x > 50000,  # binary label: income above 50,000\n", "    group=\"RAC1P\",\n", "    preprocess=adult_filter,  # mimics the UCI Adult filter ((AAGE>16) && (AGI>100) && (AFNLWGT>1) && (HRSWK>0)) on the ACS columns\n", "    postprocess=lambda x: x,  # identity: no post-processing (e.g. filling NAs) is applied\n", ")\n", "\n", "# Initialize year, duration (\"1-Year\" or \"5-Year\") and granularity (household or person)\n", "data_source = ACSDataSource(survey_year=\"2018\", horizon=\"1-Year\", survey=\"person\")\n", "# Specify region (here: California) and load data\n", "ca_data = data_source.get_data(states=[\"CA\"], download=True)\n", "# Apply the filter and label transformation defined in ACSIncome\n", "ca_features, ca_labels, ca_group = ACSIncome.df_to_numpy(ca_data)\n", "\n", "# Combine features and labels into a single dataframe\n", "df = pd.DataFrame(\n", "    np.concatenate((ca_features, ca_labels.reshape(-1, 1)), axis=1),\n", "    columns=income_features + [\">50k\"],\n", ")\n", "\n", "# For further modelling we keep only 2 groups, RAC1P 6 (Asian alone) and 8 (Some Other Race alone); see the DATAPREP notebook for details\n", "df = df[df[\"RAC1P\"].isin([6, 8])].copy(deep=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Data Processing\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Exploratory Data Analysis\n", "\n", "We look at the number of rows and columns and some simple statistics of the dataset." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "|  | AGEP | COW | SCHL | MAR | OCCP | POBP | RELP | WKHP | SEX | RAC1P | PWGTP | GCL | SCH | >50k |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | 30.0 | 6.0 | 14.0 | 1.0 | 9610.0 | 6.0 | 16.0 | 40.0 | 1.0 | 8.0 | 32.0 | 2.0 | 1.0 | 0.0 |\n",
"| 27 | 23.0 | 2.0 | 21.0 | 5.0 | 2545.0 | 207.0 | 17.0 | 20.0 | 2.0 | 6.0 | 35.0 | NaN | 3.0 | 0.0 |\n",
"| 33 | 18.0 | 1.0 | 16.0 | 5.0 | 9610.0 | 6.0 | 17.0 | 8.0 | 2.0 | 6.0 | 33.0 | NaN | 2.0 | 0.0 |\n",
"| 46 | 40.0 | 1.0 | 15.0 | 3.0 | 4140.0 | 303.0 | 16.0 | 22.0 | 1.0 | 8.0 | 38.0 | 2.0 | 1.0 | 0.0 |\n",
"| 49 | 18.0 | 1.0 | 18.0 | 5.0 | 725.0 | 6.0 | 17.0 | 12.0 | 2.0 | 6.0 | 60.0 | NaN | 2.0 | 0.0 |\n", "