{
"cells": [
{
"cell_type": "markdown",
"id": "29ccd481",
"metadata": {},
"source": [
"# XAI(Explainable AI): Kernel SHAP for Classification\n",
"\n",
"본 핸즈온에서는 앙상블과 같은 블랙 박스 모델을 설명하는 데 적합한 SHAP(SHapley Additive exPlanations)을 사용하는 예시를 보여줍니다. SHAP은 전체 셋의 feature importance가 아니라 각 샘플 데이터마다 예측에 얼마나 기여했는지 정량화가 가능합니다.\n",
"\n",
"## SHAP(SHapley Additive exPlanations)\n",
"\n",
"SHAP에 대한 심화 주제는 아래 논문과 링크를 참조하세요\n",
"- Lundberg, Scott M., and Su-In Lee. \"A Unified Approach to Interpreting Model Predictions.\" Advances in Neural Information Processing Systems, 2017.\n",
"- Interpretable ML Book (SHAP chapter): http://christophm.github.io/interpretable-ml-book/shap.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a545713e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\n",
"You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"# Use the %pip magic (not !pip) so the package is installed into the running kernel's environment.\n",
"# The cells below index SHAP results per class as a list (e.g. shap_values[0]); shap 0.45+\n",
"# returns a single ndarray for multi-output models, so stay below that release.\n",
"%pip install -qU 'shap<0.45'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a7184265",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from autogluon.tabular import TabularDataset, TabularPredictor\n",
"import pandas as pd\n",
"import numpy as np\n",
"import sklearn\n",
"import shap\n",
"shap.initjs()\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "markdown",
"id": "d268416b",
"metadata": {},
"source": [
" \n",
"\n",
"## 1. Data preparation and Training"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4817b39",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" age \n",
" workclass \n",
" fnlwgt \n",
" education \n",
" education-num \n",
" marital-status \n",
" occupation \n",
" relationship \n",
" race \n",
" sex \n",
" capital-gain \n",
" capital-loss \n",
" hours-per-week \n",
" native-country \n",
" class \n",
" \n",
" \n",
" \n",
" \n",
" 26802 \n",
" 55 \n",
" Self-emp-not-inc \n",
" 319883 \n",
" Masters \n",
" 14 \n",
" Married-civ-spouse \n",
" Exec-managerial \n",
" Husband \n",
" White \n",
" Male \n",
" 4386 \n",
" 0 \n",
" 10 \n",
" ? \n",
" >50K \n",
" \n",
" \n",
" 19134 \n",
" 78 \n",
" ? \n",
" 167336 \n",
" HS-grad \n",
" 9 \n",
" Married-civ-spouse \n",
" ? \n",
" Husband \n",
" White \n",
" Male \n",
" 0 \n",
" 0 \n",
" 16 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 37431 \n",
" 36 \n",
" Private \n",
" 190350 \n",
" HS-grad \n",
" 9 \n",
" Never-married \n",
" Adm-clerical \n",
" Not-in-family \n",
" Black \n",
" Female \n",
" 0 \n",
" 0 \n",
" 40 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 20173 \n",
" 25 \n",
" Self-emp-inc \n",
" 160261 \n",
" Bachelors \n",
" 13 \n",
" Never-married \n",
" Exec-managerial \n",
" Own-child \n",
" Asian-Pac-Islander \n",
" Male \n",
" 0 \n",
" 0 \n",
" 35 \n",
" Taiwan \n",
" <=50K \n",
" \n",
" \n",
" 3869 \n",
" 47 \n",
" Private \n",
" 216096 \n",
" Some-college \n",
" 10 \n",
" Married-spouse-absent \n",
" Exec-managerial \n",
" Unmarried \n",
" White \n",
" Female \n",
" 0 \n",
" 0 \n",
" 35 \n",
" Puerto-Rico \n",
" <=50K \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"26802 55 Self-emp-not-inc 319883 Masters 14 \n",
"19134 78 ? 167336 HS-grad 9 \n",
"37431 36 Private 190350 HS-grad 9 \n",
"20173 25 Self-emp-inc 160261 Bachelors 13 \n",
"3869 47 Private 216096 Some-college 10 \n",
"\n",
" marital-status occupation relationship \\\n",
"26802 Married-civ-spouse Exec-managerial Husband \n",
"19134 Married-civ-spouse ? Husband \n",
"37431 Never-married Adm-clerical Not-in-family \n",
"20173 Never-married Exec-managerial Own-child \n",
"3869 Married-spouse-absent Exec-managerial Unmarried \n",
"\n",
" race sex capital-gain capital-loss \\\n",
"26802 White Male 4386 0 \n",
"19134 White Male 0 0 \n",
"37431 Black Female 0 0 \n",
"20173 Asian-Pac-Islander Male 0 0 \n",
"3869 White Female 0 0 \n",
"\n",
" hours-per-week native-country class \n",
"26802 10 ? >50K \n",
"19134 16 United-States <=50K \n",
"37431 40 United-States <=50K \n",
"20173 35 Taiwan <=50K \n",
"3869 35 Puerto-Rico <=50K "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"N_SUBSAMPLE = 500  # subsample datasets for faster demo\n",
"N_TEST = 50\n",
"NSHAP_SAMPLES = 10  # how many samples to use to approximate each Shapley value, larger values will be slower\n",
"RANDOM_SEED = 0  # fixed seed so the subsamples below are identical on every Restart & Run All\n",
"\n",
"train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')  # can be local CSV file as well, returns Pandas DataFrame\n",
"train_data = train_data.sample(N_SUBSAMPLE, random_state=RANDOM_SEED)\n",
"test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')\n",
"test_data = test_data.sample(N_TEST, random_state=RANDOM_SEED)\n",
"\n",
"label = 'class'\n",
"\n",
"# Split each subsample into the target column and the remaining feature columns.\n",
"y_train = train_data[label]\n",
"y_test = test_data[label]\n",
"X_train = pd.DataFrame(train_data.drop(columns=[label]))\n",
"X_test = pd.DataFrame(test_data.drop(columns=[label]))\n",
"\n",
"display(train_data.head())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "676c8c50",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"save_path = 'ag-03-explainable-ai'\n",
"# Remove models left over from a previous run; a missing directory is not an error.\n",
"# shutil is used instead of shelling out to `rm -rf` so no variable is interpolated into a shell command.\n",
"shutil.rmtree(save_path, ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9f9f9a2c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Beginning AutoGluon training ... Time limit = 20s\n",
"AutoGluon will save models to \"ag-03-explainable-ai/\"\n",
"AutoGluon Version: 0.5.2\n",
"Python Version: 3.8.12\n",
"Operating System: Linux\n",
"Train Data Rows: 500\n",
"Train Data Columns: 14\n",
"Label Column: class\n",
"Preprocessing data ...\n",
"Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K\n",
"\tNote: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.\n",
"\tTo explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.\n",
"Using Feature Generators to preprocess the data ...\n",
"Fitting AutoMLPipelineFeatureGenerator...\n",
"\tAvailable Memory: 11886.91 MB\n",
"\tTrain Data (Original) Memory Usage: 0.29 MB (0.0% of available memory)\n",
"\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
"\tStage 1 Generators:\n",
"\t\tFitting AsTypeFeatureGenerator...\n",
"\t\t\tNote: Converting 1 features to boolean dtype as they only contain 2 unique values.\n",
"\tStage 2 Generators:\n",
"\t\tFitting FillNaFeatureGenerator...\n",
"\tStage 3 Generators:\n",
"\t\tFitting IdentityFeatureGenerator...\n",
"\t\tFitting CategoryFeatureGenerator...\n",
"\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
"\tStage 4 Generators:\n",
"\t\tFitting DropUniqueFeatureGenerator...\n",
"\tTypes of features in original data (raw dtype, special dtypes):\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]\n",
"\tTypes of features in processed data (raw dtype, special dtypes):\n",
"\t\t('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('int', ['bool']) : 1 | ['sex']\n",
"\t0.2s = Fit runtime\n",
"\t14 features in original data used to generate 14 features in processed data.\n",
"\tTrain Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)\n",
"Data preprocessing and feature engineering runtime = 0.28s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
"\tTo change this, specify the eval_metric parameter of Predictor()\n",
"Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100\n",
"Fitting 13 L1 models ...\n",
"Fitting model: KNeighborsUnif ... Training model for up to 19.72s of the 19.72s of remaining time.\n",
"\t0.66\t = Validation score (accuracy)\n",
"\t0.01s\t = Training runtime\n",
"\t0.04s\t = Validation runtime\n",
"Fitting model: KNeighborsDist ... Training model for up to 19.67s of the 19.66s of remaining time.\n",
"\t0.64\t = Validation score (accuracy)\n",
"\t0.01s\t = Training runtime\n",
"\t0.06s\t = Validation runtime\n",
"Fitting model: LightGBMXT ... Training model for up to 19.6s of the 19.6s of remaining time.\n",
"\t0.84\t = Validation score (accuracy)\n",
"\t1.28s\t = Training runtime\n",
"\t0.04s\t = Validation runtime\n",
"Fitting model: LightGBM ... Training model for up to 18.27s of the 18.26s of remaining time.\n",
"\t0.83\t = Validation score (accuracy)\n",
"\t0.52s\t = Training runtime\n",
"\t0.04s\t = Validation runtime\n",
"Fitting model: RandomForestGini ... Training model for up to 17.69s of the 17.69s of remaining time.\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t0.9s\t = Training runtime\n",
"\t0.08s\t = Validation runtime\n",
"Fitting model: RandomForestEntr ... Training model for up to 16.68s of the 16.68s of remaining time.\n",
"\t0.81\t = Validation score (accuracy)\n",
"\t0.92s\t = Training runtime\n",
"\t0.09s\t = Validation runtime\n",
"Fitting model: CatBoost ... Training model for up to 15.64s of the 15.64s of remaining time.\n",
"\t0.85\t = Validation score (accuracy)\n",
"\t3.64s\t = Training runtime\n",
"\t0.03s\t = Validation runtime\n",
"Fitting model: ExtraTreesGini ... Training model for up to 11.97s of the 11.97s of remaining time.\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t0.84s\t = Training runtime\n",
"\t0.17s\t = Validation runtime\n",
"Fitting model: ExtraTreesEntr ... Training model for up to 10.91s of the 10.91s of remaining time.\n",
"\t0.83\t = Validation score (accuracy)\n",
"\t0.92s\t = Training runtime\n",
"\t0.09s\t = Validation runtime\n",
"Fitting model: NeuralNetFastAI ... Training model for up to 9.87s of the 9.87s of remaining time.\n",
"No improvement since epoch 5: early stopping\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t2.48s\t = Training runtime\n",
"\t0.05s\t = Validation runtime\n",
"Fitting model: XGBoost ... Training model for up to 7.31s of the 7.3s of remaining time.\n",
"\t0.86\t = Validation score (accuracy)\n",
"\t1.07s\t = Training runtime\n",
"\t0.03s\t = Validation runtime\n",
"Fitting model: NeuralNetTorch ... Training model for up to 6.2s of the 6.19s of remaining time.\n",
"\tRan out of time, stopping training early. (Stopping on epoch 24)\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t6.23s\t = Training runtime\n",
"\t0.04s\t = Validation runtime\n",
"Fitting model: WeightedEnsemble_L2 ... Training model for up to 19.72s of the -0.11s of remaining time.\n",
"\t0.87\t = Validation score (accuracy)\n",
"\t0.64s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"AutoGluon training complete, total runtime = 20.81s ... Best model: \"WeightedEnsemble_L2\"\n",
"TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-03-explainable-ai/\")\n"
]
}
],
"source": [
"# Train the binary income classifier; fitting is capped at 20 seconds to keep the demo fast.\n",
"predictor = TabularPredictor(\n",
"    label=label,\n",
"    path=save_path,\n",
"    problem_type='binary',\n",
").fit(train_data, time_limit=20)"
]
},
{
"cell_type": "markdown",
"id": "3b0b9ccd",
"metadata": {},
"source": [
" \n",
"\n",
"## 2. Explain predictions\n",
"\n",
"SHAP은 각 피쳐가 예측 결과에 \"얼마나\" 기여하는지 설명합니다. 구체적으로 baseline에서 positive 클래스의 예측 확률 간의 편차로 정량화되며,\n",
"신규 데이터에 대한 예측 시에는 훈련 데이터에 대한 평균 예측과 다른 각 피쳐가 예측에 얼마나 기여하는지 정량화합니다."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b5fa52ca",
"metadata": {},
"outputs": [],
"source": [
"class AutogluonWrapper:\n",
"    \"\"\"Adapter exposing an AutoGluon predictor through a plain ``predict_proba`` callable.\n",
"\n",
"    Inputs that are not already DataFrames (a single ``pd.Series`` row or an array-like)\n",
"    are rebuilt into a DataFrame labelled with ``feature_names`` before being forwarded\n",
"    to the wrapped predictor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, predictor, feature_names):\n",
"        # Column labels used when rebuilding DataFrames, and the wrapped model itself.\n",
"        self.feature_names = feature_names\n",
"        self.ag_model = predictor\n",
"\n",
"    def predict_proba(self, X):\n",
"        \"\"\"Return ``self.ag_model.predict_proba`` evaluated on ``X`` presented as a DataFrame.\"\"\"\n",
"        if isinstance(X, pd.DataFrame):\n",
"            frame = X\n",
"        else:\n",
"            # A Series is treated as one row: flatten its values to shape (1, n_values) first.\n",
"            rows = X.values.reshape(1, -1) if isinstance(X, pd.Series) else X\n",
"            frame = pd.DataFrame(rows, columns=self.feature_names)\n",
"        return self.ag_model.predict_proba(frame)"
]
},
{
"cell_type": "markdown",
"id": "7976039a",
"metadata": {},
"source": [
"피쳐의 baseline reference 값을 정의합니다. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "eb0640d3",
"metadata": {},
"outputs": [],
"source": [
"# X_train.mode() could also be reasonable baseline for both numerical/categorical features rather than an entire dataset.\n",
"# The sample is seeded for reproducibility and capped at the number of available rows,\n",
"# so a smaller training subsample does not make .sample() raise.\n",
"baseline = X_train.sample(min(100, len(X_train)), random_state=0)"
]
},
{
"cell_type": "markdown",
"id": "d90321a0",
"metadata": {},
"source": [
"AutoGluon 예측 결과를 설명하기 위해 Kernel SHAP 값을 반환하는 KernelExplainer를 생성합니다."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "75f4351a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline prediction: <=50K 0.582965\n",
" >50K 0.417035\n",
"dtype: float64\n"
]
}
],
"source": [
"ag_wrapper = AutogluonWrapper(predictor, X_train.columns)\n",
"explainer = shap.KernelExplainer(ag_wrapper.predict_proba, baseline)\n",
"# Average per class with an explicit axis=0: without it, newer pandas versions reduce the\n",
"# whole probability table to a single scalar instead of one mean per class.\n",
"print(\"Baseline prediction: \", np.mean(ag_wrapper.predict_proba(baseline), axis=0))  # this is the same as explainer.expected_value"
]
},
{
"cell_type": "markdown",
"id": "c408c554",
"metadata": {},
"source": [
"### SHAP for single datapoint\n",
"\n",
"훈련 데이터셋 내의 임의의 데이터 포인트에 대해 SHAP을 plot해 보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f05733d8",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "91b366ced7c24c25b56889b52e3239d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Positional index of the training row to explain.\n",
"ROW_INDEX = 0\n",
"\n",
"# The double brackets keep the selection as a one-row DataFrame rather than a Series.\n",
"single_datapoint = X_train.iloc[[ROW_INDEX]]\n",
"single_prediction = ag_wrapper.predict_proba(single_datapoint)\n",
"shap_values_single = explainer.shap_values(\n",
"    single_datapoint,\n",
"    nsamples=NSHAP_SAMPLES,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "db505e59",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Force plot of the selected row for output index 0.\n",
"shap.force_plot(\n",
"    explainer.expected_value[0],\n",
"    shap_values_single[0],\n",
"    X_train.iloc[ROW_INDEX, :],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9c6153a8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Force plot of the same row for output index 1.\n",
"shap.force_plot(\n",
"    explainer.expected_value[1],\n",
"    shap_values_single[1],\n",
"    X_train.iloc[ROW_INDEX, :],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "34f7781c",
"metadata": {},
"source": [
"### SHAP for dataset\n",
"\n",
"테스트 데이터셋의 모든 데이터 포인트에 대해서도 SHAP을 plot할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6b61817c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a74096f915fc41d882de9532589c4b91",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/50 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Explain every row of the test subsample, then draw the combined force plot for output index 0.\n",
"shap_values = explainer.shap_values(X_test, nsamples=NSHAP_SAMPLES)\n",
"shap.force_plot(\n",
"    explainer.expected_value[0],\n",
"    shap_values[0],\n",
"    X_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "109ebddd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Summary over all outputs; pass the predictor's class labels so the legend shows the real\n",
"# class names instead of generic \"Class N\" placeholders.\n",
"shap.summary_plot(shap_values, X_test, class_names=predictor.class_labels)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "91766657",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Summary plot restricted to the SHAP values of output index 0.\n",
"shap.summary_plot(\n",
"    shap_values[0],\n",
"    X_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8908346d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Dependence plot of the 'education-num' feature against its SHAP values for output index 0.\n",
"shap.dependence_plot(\n",
"    \"education-num\",\n",
"    shap_values[0],\n",
"    X_test,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2a7d304f",
"metadata": {},
"source": [
"### Overall Feature Importance \n",
"\n",
"개별 예측을 설명하는 대신 각 피쳐가 AutoGluon의 일반적인 예측 정확도에 얼마나 기여하는지 알고 싶다면 Permutation Shuffling을 활용할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c4053299",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Computing feature importance via permutation shuffling for 14 features using 50 rows with 5 shuffle sets...\n",
"\t7.3s\t= Expected runtime (1.46s per shuffle set)\n",
"\t0.87s\t= Actual runtime (Completed 5 of 5 shuffle sets)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" importance \n",
" stddev \n",
" p_value \n",
" n \n",
" p99_high \n",
" p99_low \n",
" \n",
" \n",
" \n",
" \n",
" marital-status \n",
" 0.064 \n",
" 0.029665 \n",
" 0.004249 \n",
" 5 \n",
" 0.125080 \n",
" 0.002920 \n",
" \n",
" \n",
" education-num \n",
" 0.048 \n",
" 0.017889 \n",
" 0.001941 \n",
" 5 \n",
" 0.084833 \n",
" 0.011167 \n",
" \n",
" \n",
" capital-gain \n",
" 0.032 \n",
" 0.017889 \n",
" 0.008065 \n",
" 5 \n",
" 0.068833 \n",
" -0.004833 \n",
" \n",
" \n",
" capital-loss \n",
" 0.024 \n",
" 0.026077 \n",
" 0.054350 \n",
" 5 \n",
" 0.077693 \n",
" -0.029693 \n",
" \n",
" \n",
" age \n",
" 0.004 \n",
" 0.043359 \n",
" 0.423322 \n",
" 5 \n",
" 0.093277 \n",
" -0.085277 \n",
" \n",
" \n",
" relationship \n",
" 0.004 \n",
" 0.016733 \n",
" 0.310654 \n",
" 5 \n",
" 0.038454 \n",
" -0.030454 \n",
" \n",
" \n",
" workclass \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" education \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" race \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" native-country \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" fnlwgt \n",
" -0.008 \n",
" 0.010954 \n",
" 0.911096 \n",
" 5 \n",
" 0.014555 \n",
" -0.030555 \n",
" \n",
" \n",
" sex \n",
" -0.008 \n",
" 0.010954 \n",
" 0.911096 \n",
" 5 \n",
" 0.014555 \n",
" -0.030555 \n",
" \n",
" \n",
" hours-per-week \n",
" -0.008 \n",
" 0.030332 \n",
" 0.706475 \n",
" 5 \n",
" 0.054453 \n",
" -0.070453 \n",
" \n",
" \n",
" occupation \n",
" -0.024 \n",
" 0.008944 \n",
" 0.998059 \n",
" 5 \n",
" -0.005584 \n",
" -0.042416 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" importance stddev p_value n p99_high p99_low\n",
"marital-status 0.064 0.029665 0.004249 5 0.125080 0.002920\n",
"education-num 0.048 0.017889 0.001941 5 0.084833 0.011167\n",
"capital-gain 0.032 0.017889 0.008065 5 0.068833 -0.004833\n",
"capital-loss 0.024 0.026077 0.054350 5 0.077693 -0.029693\n",
"age 0.004 0.043359 0.423322 5 0.093277 -0.085277\n",
"relationship 0.004 0.016733 0.310654 5 0.038454 -0.030454\n",
"workclass 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"education 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"race 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"native-country 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"fnlwgt -0.008 0.010954 0.911096 5 0.014555 -0.030555\n",
"sex -0.008 0.010954 0.911096 5 0.014555 -0.030555\n",
"hours-per-week -0.008 0.030332 0.706475 5 0.054453 -0.070453\n",
"occupation -0.024 0.008944 0.998059 5 -0.005584 -0.042416"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Feature importance of the fitted predictor, evaluated on the labelled test subsample.\n",
"predictor.feature_importance(\n",
"    test_data,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "cb467b92",
"metadata": {},
"source": [
" \n",
"\n",
"## 3. Multiclass Classification\n",
"다중(multi) 클래스 분류도 SHAP 적용이 가능합니다. 이번에는 개인 소득 대신 가족 관계(relationship)를 예측하는 문제로 변경해서 훈련을 수행 후 SHAP을 확인해 보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d24c1b60",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" age \n",
" workclass \n",
" fnlwgt \n",
" education \n",
" education-num \n",
" marital-status \n",
" occupation \n",
" relationship \n",
" race \n",
" sex \n",
" capital-gain \n",
" capital-loss \n",
" hours-per-week \n",
" native-country \n",
" class \n",
" \n",
" \n",
" \n",
" \n",
" 26802 \n",
" 55 \n",
" Self-emp-not-inc \n",
" 319883 \n",
" Masters \n",
" 14 \n",
" Married-civ-spouse \n",
" Exec-managerial \n",
" Husband \n",
" White \n",
" Male \n",
" 4386 \n",
" 0 \n",
" 10 \n",
" ? \n",
" >50K \n",
" \n",
" \n",
" 19134 \n",
" 78 \n",
" ? \n",
" 167336 \n",
" HS-grad \n",
" 9 \n",
" Married-civ-spouse \n",
" ? \n",
" Husband \n",
" White \n",
" Male \n",
" 0 \n",
" 0 \n",
" 16 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 37431 \n",
" 36 \n",
" Private \n",
" 190350 \n",
" HS-grad \n",
" 9 \n",
" Never-married \n",
" Adm-clerical \n",
" Not-in-family \n",
" Black \n",
" Female \n",
" 0 \n",
" 0 \n",
" 40 \n",
" United-States \n",
" <=50K \n",
" \n",
" \n",
" 20173 \n",
" 25 \n",
" Self-emp-inc \n",
" 160261 \n",
" Bachelors \n",
" 13 \n",
" Never-married \n",
" Exec-managerial \n",
" Own-child \n",
" Asian-Pac-Islander \n",
" Male \n",
" 0 \n",
" 0 \n",
" 35 \n",
" Taiwan \n",
" <=50K \n",
" \n",
" \n",
" 3869 \n",
" 47 \n",
" Private \n",
" 216096 \n",
" Some-college \n",
" 10 \n",
" Married-spouse-absent \n",
" Exec-managerial \n",
" Unmarried \n",
" White \n",
" Female \n",
" 0 \n",
" 0 \n",
" 35 \n",
" Puerto-Rico \n",
" <=50K \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"26802 55 Self-emp-not-inc 319883 Masters 14 \n",
"19134 78 ? 167336 HS-grad 9 \n",
"37431 36 Private 190350 HS-grad 9 \n",
"20173 25 Self-emp-inc 160261 Bachelors 13 \n",
"3869 47 Private 216096 Some-college 10 \n",
"\n",
" marital-status occupation relationship \\\n",
"26802 Married-civ-spouse Exec-managerial Husband \n",
"19134 Married-civ-spouse ? Husband \n",
"37431 Never-married Adm-clerical Not-in-family \n",
"20173 Never-married Exec-managerial Own-child \n",
"3869 Married-spouse-absent Exec-managerial Unmarried \n",
"\n",
" race sex capital-gain capital-loss \\\n",
"26802 White Male 4386 0 \n",
"19134 White Male 0 0 \n",
"37431 Black Female 0 0 \n",
"20173 Asian-Pac-Islander Male 0 0 \n",
"3869 White Female 0 0 \n",
"\n",
" hours-per-week native-country class \n",
"26802 10 ? >50K \n",
"19134 16 United-States <=50K \n",
"37431 40 United-States <=50K \n",
"20173 35 Taiwan <=50K \n",
"3869 35 Puerto-Rico <=50K "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Possible classes: \n",
" Husband 234\n",
" Not-in-family 114\n",
" Own-child 78\n",
" Unmarried 49\n",
" Wife 15\n",
" Other-relative 10\n",
"Name: relationship, dtype: int64\n"
]
}
],
"source": [
"# Switch the prediction target to 'relationship'; every other column is now a feature.\n",
"label = 'relationship'\n",
"\n",
"X_train = pd.DataFrame(train_data.drop(columns=[label]))\n",
"y_train = train_data[label]\n",
"X_test = pd.DataFrame(test_data.drop(columns=[label]))\n",
"y_test = test_data[label]\n",
"\n",
"display(train_data.head())\n",
"print(\"Possible classes: \\n\", train_data[label].value_counts())"
]
},
{
"cell_type": "markdown",
"id": "1a749f8a",
"metadata": {},
"source": [
"`problem_type`을 지정하지 않아도 AutoGluon에서 자동으로 처리하지만, 안전하게 `problem_type=multiclass`로 지정합니다."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "02b62ff8",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"save_path = 'ag-03-explainable-ai-multiclass'\n",
"# Remove models left over from a previous run; a missing directory is not an error.\n",
"# shutil is used instead of shelling out to `rm -rf` so no variable is interpolated into a shell command.\n",
"shutil.rmtree(save_path, ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "de4d4923",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Beginning AutoGluon training ... Time limit = 20s\n",
"AutoGluon will save models to \"ag-03-explainable-ai-multiclass/\"\n",
"AutoGluon Version: 0.5.2\n",
"Python Version: 3.8.12\n",
"Operating System: Linux\n",
"Train Data Rows: 500\n",
"Train Data Columns: 14\n",
"Label Column: relationship\n",
"Preprocessing data ...\n",
"Train Data Class Count: 6\n",
"Using Feature Generators to preprocess the data ...\n",
"Fitting AutoMLPipelineFeatureGenerator...\n",
"\tAvailable Memory: 13576.74 MB\n",
"\tTrain Data (Original) Memory Usage: 0.29 MB (0.0% of available memory)\n",
"\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
"\tStage 1 Generators:\n",
"\t\tFitting AsTypeFeatureGenerator...\n",
"\t\t\tNote: Converting 2 features to boolean dtype as they only contain 2 unique values.\n",
"\tStage 2 Generators:\n",
"\t\tFitting FillNaFeatureGenerator...\n",
"\tStage 3 Generators:\n",
"\t\tFitting IdentityFeatureGenerator...\n",
"\t\tFitting CategoryFeatureGenerator...\n",
"\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
"\tStage 4 Generators:\n",
"\t\tFitting DropUniqueFeatureGenerator...\n",
"\tTypes of features in original data (raw dtype, special dtypes):\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'race', ...]\n",
"\tTypes of features in processed data (raw dtype, special dtypes):\n",
"\t\t('category', []) : 6 | ['workclass', 'education', 'marital-status', 'occupation', 'race', ...]\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('int', ['bool']) : 2 | ['sex', 'class']\n",
"\t0.1s = Fit runtime\n",
"\t14 features in original data used to generate 14 features in processed data.\n",
"\tTrain Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)\n",
"Data preprocessing and feature engineering runtime = 0.09s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
"\tTo change this, specify the eval_metric parameter of Predictor()\n",
"Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100\n",
"Fitting 13 L1 models ...\n",
"Fitting model: KNeighborsUnif ... Training model for up to 19.91s of the 19.91s of remaining time.\n",
"\t0.44\t = Validation score (accuracy)\n",
"\t0.0s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: KNeighborsDist ... Training model for up to 19.9s of the 19.9s of remaining time.\n",
"\t0.33\t = Validation score (accuracy)\n",
"\t0.0s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"Fitting model: NeuralNetFastAI ... Training model for up to 19.89s of the 19.89s of remaining time.\n",
"\t0.72\t = Validation score (accuracy)\n",
"\t0.82s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBMXT ... Training model for up to 19.05s of the 19.04s of remaining time.\n",
"\t0.79\t = Validation score (accuracy)\n",
"\t0.52s\t = Training runtime\n",
"\t0.05s\t = Validation runtime\n",
"Fitting model: LightGBM ... Training model for up to 18.46s of the 18.46s of remaining time.\n",
"\t0.79\t = Validation score (accuracy)\n",
"\t0.78s\t = Training runtime\n",
"\t0.03s\t = Validation runtime\n",
"Fitting model: RandomForestGini ... Training model for up to 17.63s of the 17.63s of remaining time.\n",
"\t0.75\t = Validation score (accuracy)\n",
"\t0.64s\t = Training runtime\n",
"\t0.06s\t = Validation runtime\n",
"Fitting model: RandomForestEntr ... Training model for up to 16.91s of the 16.91s of remaining time.\n",
"\t0.75\t = Validation score (accuracy)\n",
"\t0.69s\t = Training runtime\n",
"\t0.08s\t = Validation runtime\n",
"Fitting model: CatBoost ... Training model for up to 16.11s of the 16.11s of remaining time.\n",
"\tRan out of time, early stopping on iteration 312.\n",
"\t0.81\t = Validation score (accuracy)\n",
"\t16.08s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: ExtraTreesGini ... Training model for up to 0.02s of the 0.02s of remaining time.\n",
"\tWarning: Model is expected to require 1.0s to train, which exceeds the maximum time limit of 0.0s, skipping model...\n",
"\tTime limit exceeded... Skipping ExtraTreesGini.\n",
"Fitting model: WeightedEnsemble_L2 ... Training model for up to 19.91s of the -0.03s of remaining time.\n",
"\t0.81\t = Validation score (accuracy)\n",
"\t0.22s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"AutoGluon training complete, total runtime = 20.26s ... Best model: \"WeightedEnsemble_L2\"\n",
"TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-03-explainable-ai-multiclass/\")\n"
]
}
],
"source": [
"# Train the multiclass relationship classifier; fitting is capped at 20 seconds to keep the demo fast.\n",
"predictor_multi = TabularPredictor(\n",
"    label=label,\n",
"    path=save_path,\n",
"    problem_type='multiclass',\n",
").fit(train_data, time_limit=20)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e3fdd66d",
"metadata": {},
"outputs": [],
"source": [
"# X_train.mode() could also be reasonable baseline for both numerical/categorical features rather than an entire dataset.\n",
"# The sample is seeded for reproducibility and capped at the number of available rows,\n",
"# so a smaller training subsample does not make .sample() raise.\n",
"baseline = X_train.sample(min(100, len(X_train)), random_state=0)\n",
"\n",
"ag_wrapper = AutogluonWrapper(predictor_multi, X_train.columns)\n",
"explainer = shap.KernelExplainer(ag_wrapper.predict_proba, baseline)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "fdf21cbf",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" 0 \n",
" \n",
" \n",
" \n",
" \n",
" Husband \n",
" 0.491491 \n",
" \n",
" \n",
" Not-in-family \n",
" 0.213104 \n",
" \n",
" \n",
" Other-relative \n",
" 0.011575 \n",
" \n",
" \n",
" Own-child \n",
" 0.191880 \n",
" \n",
" \n",
" Unmarried \n",
" 0.072857 \n",
" \n",
" \n",
" Wife \n",
" 0.019095 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0\n",
" Husband 0.491491\n",
" Not-in-family 0.213104\n",
" Other-relative 0.011575\n",
" Own-child 0.191880\n",
" Unmarried 0.072857\n",
" Wife 0.019095"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean predicted probability per class over the baseline rows, shown as a one-column table.\n",
"ag_wrapper.predict_proba(baseline).mean(axis=0).to_frame()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a8fb0561",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class Info: \n",
" [' Husband', ' Not-in-family', ' Other-relative', ' Own-child', ' Unmarried', ' Wife']\n"
]
},
{
"data": {
"text/html": [
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(\"Class Info: \\n\", predictor_multi.class_labels)\n",
"\n",
"NSHAP_SAMPLES = 10 # how many samples to use to approximate each Shapely value, larger values will be slower\n",
"shap.initjs()"
]
},
{
"cell_type": "markdown",
"id": "8294338d",
"metadata": {},
"source": [
"class 중 Not-in-family에 대해서 SHAP을 plot해 보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "eb36161d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12ba0b3896fe44e89e5a2c6109636654",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapely values: \n",
" {' Husband': array([[0. , 0. , 0. , 0. , 0. ,\n",
" 0.38148441, 0. , 0. , 0.09544831, 0. ,\n",
" 0. , 0. , 0. , 0. ]]), ' Not-in-family': array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" -0.20806026, 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. ]]), ' Other-relative': array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" -0.00548243, 0. , 0. , -0.00209195, 0. ,\n",
" 0. , 0. , 0. , 0. ]]), ' Own-child': array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" -0.1856742, 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. ]]), ' Unmarried': array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" -0.05492021, 0. , 0. , -0.01444695, 0. ,\n",
" 0. , 0. , 0. , 0. ]]), ' Wife': array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" 0.07131596, 0. , 0. , -0.0775727 , 0. ,\n",
" 0. , 0. , 0. , 0. ]])}\n",
"Force_plot for class: Not-in-family\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ROW_INDEX = 0 # index of an example datapoint\n",
"class_of_interest = ' Not-in-family' # can be any value in set(y_train)\n",
"class_index = predictor_multi.class_labels.index(class_of_interest)\n",
"\n",
"single_datapoint = X_train.iloc[[ROW_INDEX]]\n",
"single_prediction = ag_wrapper.predict_proba(single_datapoint)\n",
"\n",
"shap_values_single = explainer.shap_values(single_datapoint, nsamples=NSHAP_SAMPLES)\n",
"print(\"Shapely values: \\n\", {predictor_multi.class_labels[i]:shap_values_single[i] for i in range(len(predictor_multi.class_labels))})\n",
"\n",
"print(f\"Force_plot for class: {class_of_interest}\")\n",
"shap.force_plot(explainer.expected_value[class_index], shap_values_single[class_index], X_train.iloc[ROW_INDEX,:])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "f0bec6a3",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb75a241cd944135852b48f4a08682d1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/50 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Force_plot for class: Not-in-family\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shap_values = explainer.shap_values(X_test, nsamples=NSHAP_SAMPLES)\n",
"\n",
"print(f\"Force_plot for class: {class_of_interest}\")\n",
"shap.force_plot(explainer.expected_value[class_index], shap_values[class_index], X_test)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "1d7069b6",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'Class 0': ' Husband', 'Class 1': ' Not-in-family', 'Class 2': ' Other-relative', 'Class 3': ' Own-child', 'Class 4': ' Unmarried', 'Class 5': ' Wife'}\n"
]
}
],
"source": [
"shap.summary_plot(shap_values, X_test)\n",
"print({\"Class \"+str(i) : predictor_multi.class_labels[i] for i in range(len(predictor_multi.class_labels))})"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "244b48e1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dependence_plot for class: Not-in-family and for feature: marital-status \n",
"\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dependence_feature = \"marital-status\"\n",
"print(f\"Dependence_plot for class: {class_of_interest} and for feature: {dependence_feature} \\n\")\n",
"\n",
"shap.dependence_plot(dependence_feature, shap_values[class_index], X_test)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "250efa4d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Computing feature importance via permutation shuffling for 14 features using 50 rows with 5 shuffle sets...\n",
"\t7.57s\t= Expected runtime (1.51s per shuffle set)\n",
"\t0.87s\t= Actual runtime (Completed 5 of 5 shuffle sets)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" importance \n",
" stddev \n",
" p_value \n",
" n \n",
" p99_high \n",
" p99_low \n",
" \n",
" \n",
" \n",
" \n",
" marital-status \n",
" 0.064 \n",
" 0.029665 \n",
" 0.004249 \n",
" 5 \n",
" 0.125080 \n",
" 0.002920 \n",
" \n",
" \n",
" education-num \n",
" 0.048 \n",
" 0.017889 \n",
" 0.001941 \n",
" 5 \n",
" 0.084833 \n",
" 0.011167 \n",
" \n",
" \n",
" capital-gain \n",
" 0.032 \n",
" 0.017889 \n",
" 0.008065 \n",
" 5 \n",
" 0.068833 \n",
" -0.004833 \n",
" \n",
" \n",
" capital-loss \n",
" 0.024 \n",
" 0.026077 \n",
" 0.054350 \n",
" 5 \n",
" 0.077693 \n",
" -0.029693 \n",
" \n",
" \n",
" age \n",
" 0.004 \n",
" 0.043359 \n",
" 0.423322 \n",
" 5 \n",
" 0.093277 \n",
" -0.085277 \n",
" \n",
" \n",
" relationship \n",
" 0.004 \n",
" 0.016733 \n",
" 0.310654 \n",
" 5 \n",
" 0.038454 \n",
" -0.030454 \n",
" \n",
" \n",
" workclass \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" education \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" race \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" native-country \n",
" 0.000 \n",
" 0.000000 \n",
" 0.500000 \n",
" 5 \n",
" 0.000000 \n",
" 0.000000 \n",
" \n",
" \n",
" fnlwgt \n",
" -0.008 \n",
" 0.010954 \n",
" 0.911096 \n",
" 5 \n",
" 0.014555 \n",
" -0.030555 \n",
" \n",
" \n",
" sex \n",
" -0.008 \n",
" 0.010954 \n",
" 0.911096 \n",
" 5 \n",
" 0.014555 \n",
" -0.030555 \n",
" \n",
" \n",
" hours-per-week \n",
" -0.008 \n",
" 0.030332 \n",
" 0.706475 \n",
" 5 \n",
" 0.054453 \n",
" -0.070453 \n",
" \n",
" \n",
" occupation \n",
" -0.024 \n",
" 0.008944 \n",
" 0.998059 \n",
" 5 \n",
" -0.005584 \n",
" -0.042416 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" importance stddev p_value n p99_high p99_low\n",
"marital-status 0.064 0.029665 0.004249 5 0.125080 0.002920\n",
"education-num 0.048 0.017889 0.001941 5 0.084833 0.011167\n",
"capital-gain 0.032 0.017889 0.008065 5 0.068833 -0.004833\n",
"capital-loss 0.024 0.026077 0.054350 5 0.077693 -0.029693\n",
"age 0.004 0.043359 0.423322 5 0.093277 -0.085277\n",
"relationship 0.004 0.016733 0.310654 5 0.038454 -0.030454\n",
"workclass 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"education 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"race 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"native-country 0.000 0.000000 0.500000 5 0.000000 0.000000\n",
"fnlwgt -0.008 0.010954 0.911096 5 0.014555 -0.030555\n",
"sex -0.008 0.010954 0.911096 5 0.014555 -0.030555\n",
"hours-per-week -0.008 0.030332 0.706475 5 0.054453 -0.070453\n",
"occupation -0.024 0.008944 0.998059 5 -0.005584 -0.042416"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictor.feature_importance(test_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}