{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Responsible AI - Exploratory Data Analysis\n", "\n", "This notebook shows how to quantify and visualize correlations (scatter plots, correlation matrix) and generate descriptive statistics (histograms and bar plots). To measure bias before training a model, we will use the normalized class imbalance ($CI_{norm}$) and the difference in proportions of labels ($DPL$). These are only two examples of measures that can be computed before training. Make sure to try additional measures when building a model.\n", "\n", "__Dataset:__ \n", "You will download a dataset for this exercise using [folktables](https://github.com/zykls/folktables). Folktables provides an API to download data from the American Community Survey (ACS) Public Use Microdata Sample (PUMS) files, which are managed by the US Census Bureau. The data itself is governed by the terms of use provided by the Census Bureau. For more information, see the [Terms of Service](https://www.census.gov/data/developers/about/terms-of-service.html).\n", "\n", "\n", "__ML Problem:__ \n", "Ultimately, the goal will be to predict whether an individual's income is above \\\\$50,000. We will filter the ACS PUMS data sample to only include individuals above the age of 16 who reported usual working hours of at least 1 hour per week in the past year and an income of at least \\\\$100. The threshold of \\\\$50,000 was chosen so that this dataset can serve as a comparable substitute for the [UCI Adult dataset](https://archive.ics.uci.edu/ml/datasets/adult). The income threshold can be changed easily to define new prediction tasks.\n", "\n", "__Table of contents__\n", "\n", "1. Loading Data\n", "2. Data Overview\n", "3. Bar Plots \\& Histograms\n", "4. Scatter Plots\n", "5. Correlation Matrix\n", "6. $CI_{norm}$ and $DPL$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook assumes an installation of the SageMaker kernel `conda_pytorch_p39`. 
In addition, the libraries listed in `requirements.txt` need to be installed:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install --no-deps -U -q -r ../../requirements.txt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Reshaping/basic libraries\n", "import pandas as pd\n", "\n", "# Plotting libraries\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "sns.set_style(\"darkgrid\", {\"axes.facecolor\": \".9\"})\n", "\n", "# Operational libraries\n", "import sys\n", "\n", "sys.path.append(\"..\")\n", "\n", "# Fairness libraries\n", "from folktables.acs import *\n", "from folktables.folktables import *\n", "from folktables.load_acs import *\n", "\n", "# Jupyter(lab) libraries\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Loading Data\n", "(Go to top)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To read in the dataset, we will use [folktables](https://github.com/zykls/folktables), which provides access to the US Census dataset. Folktables contains predefined prediction tasks but also allows the user to specify the problem type." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The US Census dataset distinguishes between households and individuals. To obtain data on individuals, we use `ACSDataSource` with `survey=\"person\"`. The feature names for the US Census data follow the same distinction and use `P` for `person` and `H` for `household`; for example, `AGEP` refers to the age of an individual."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data for 2018 1-Year person survey for CA...\n" ] } ], "source": [ "income_features = [\n", " \"AGEP\", # age of the individual\n", " \"COW\", # class of worker\n", " \"SCHL\", # educational attainment\n", " \"MAR\", # marital status\n", " \"OCCP\", # occupation\n", " \"POBP\", # place of birth\n", " \"RELP\", # relationship\n", " \"WKHP\", # hours worked per week past 12 months\n", " \"SEX\", # sex\n", " \"RAC1P\", # recorded detailed race code\n", " \"PWGTP\", # person's weight\n", " \"GCL\", # grandparents living with grandchildren\n", " \"SCH\", # school enrollment\n", "]\n", "\n", "# Define the prediction problem and features\n", "ACSIncome = folktables.BasicProblem(\n", " features=income_features,\n", " target=\"PINCP\", # total person's income\n", " target_transform=lambda x: x > 50000,\n", " group=\"RAC1P\",\n", " preprocess=adult_filter, # keeps rows where (AAGE>16) && (AGI>100) && (AFNLWGT>1) && (HRSWK>0)\n", " postprocess=lambda x: x, # hook for post-processing, e.g. filling all NAs (identity here)\n", ")\n", "\n", "# Initialize year, duration (\"1-Year\" or \"5-Year\") and granularity (household or person)\n", "data_source = ACSDataSource(survey_year=\"2018\", horizon=\"1-Year\", survey=\"person\")\n", "# Specify region (here: California) and load data\n", "ca_data = data_source.get_data(states=[\"CA\"], download=True)\n", "# Apply transformation as per problem statement above\n", "ca_features, ca_labels, ca_group = ACSIncome.df_to_numpy(ca_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Data Overview\n", "(Go to top)\n", "\n", "We will go through the basic steps of exploratory data analysis (EDA): an initial investigation of the data to discover patterns, spot anomalies, and find insights that inform later ML modeling choices."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | AGEP | COW | SCHL | MAR | OCCP | POBP | RELP | WKHP | SEX | RAC1P | PWGTP | GCL | SCH | >50k |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 30.0 | 6.0 | 14.0 | 1.0 | 9610.0 | 6.0 | 16.0 | 40.0 | 1.0 | 8.0 | 32.0 | 2.0 | 1.0 | 0.0 |
| 1 | 21.0 | 4.0 | 16.0 | 5.0 | 1970.0 | 6.0 | 17.0 | 20.0 | 1.0 | 1.0 | 52.0 | NaN | 2.0 | 0.0 |
| 2 | 65.0 | 2.0 | 22.0 | 5.0 | 2040.0 | 6.0 | 17.0 | 8.0 | 1.0 | 1.0 | 33.0 | 2.0 | 1.0 | 0.0 |
| 3 | 33.0 | 1.0 | 14.0 | 3.0 | 9610.0 | 36.0 | 16.0 | 40.0 | 1.0 | 1.0 | 53.0 | 2.0 | 1.0 | 0.0 |
| 4 | 18.0 | 2.0 | 19.0 | 5.0 | 1021.0 | 6.0 | 17.0 | 18.0 | 2.0 | 1.0 | 106.0 | NaN | 3.0 | 0.0 |