In [None]:
import autogluon as ag
from autogluon.tabular import TabularDataset, TabularPredictor

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import collections
from sklearn.model_selection import StratifiedKFold, train_test_split, ShuffleSplit
from sklearn.metrics import (confusion_matrix, classification_report, accuracy_score, roc_curve, average_precision_score,
 precision_recall_curve, precision_score, recall_score, f1_score, matthews_corrcoef, auc)
try:
 from joblib import dump, load
except ImportError:
 from sklearn.externals.joblib import dump, load

 
def plot_roc_curve(y_true, y_score, is_single_fig=False):
    """
    Draw the ROC curve on the current axes and put the AUROC in the title.

    y_true        -- ground-truth binary labels
    y_score       -- predicted scores passed to sklearn's roc_curve
    is_single_fig -- when truthy, plt.show() is called after drawing
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_score)
    area = auc(false_pos_rate, true_pos_rate)
    padded_unit_range = [-0.05, 1.05]

    plt.title('AUROC = {:.4f}'.format(area))
    plt.plot(false_pos_rate, true_pos_rate, 'b')
    # Dashed diagonal as the reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim(padded_unit_range)
    plt.ylim(padded_unit_range)
    plt.ylabel('TPR(True Positive Rate)')
    plt.xlabel('FPR(False Positive Rate)')

    if is_single_fig:
        plt.show()
 
def plot_pr_curve(y_true, y_score, is_single_fig=False):
    """
    Draw the precision-recall curve on the current axes and put the
    average precision (AUPRC) in the title.

    y_true        -- ground-truth binary labels
    y_score       -- predicted scores passed to sklearn's precision_recall_curve
    is_single_fig -- when truthy, plt.show() is called after drawing
    """
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap_score = average_precision_score(y_true, y_score)
    light_blue = dict(color='b', alpha=0.2)
    padded_unit_range = [-0.05, 1.05]

    plt.title('AUPRC = {:.4f}'.format(ap_score))
    # Faint step outline plus matching fill under the curve, then the solid line on top.
    plt.step(recall, precision, where='post', **light_blue)
    plt.fill_between(recall, precision, step='post', **light_blue)
    plt.plot(recall, precision, 'b')
    plt.xlim(padded_unit_range)
    plt.ylim(padded_unit_range)
    plt.ylabel('Precision')
    plt.xlabel('Recall')

    if is_single_fig:
        plt.show()

def plot_conf_mtx(y_true, y_score, thresh=0.5, class_labels=['0','1'], is_single_fig=False):
    """
    Print a classification report and draw the confusion matrix as a heatmap.

    Scores are binarized with ``y_score >= thresh`` (1 when true, else 0).

    y_true        -- ground-truth binary labels, expected to be encoded as 0/1
    y_score       -- predicted scores (any array-like; converted with np.asarray)
    thresh        -- cutoff applied to y_score
    class_labels  -- display names for class 0 and class 1, in that order
    is_single_fig -- when truthy, plt.show() is called after drawing
    """
    # Copy so the shared default list is never handed to third-party code directly.
    tick_labels = list(class_labels)
    # np.asarray lets plain Python lists be compared against the threshold.
    y_pred = np.where(np.asarray(y_score) >= thresh, 1, 0)
    # Pin the label set: otherwise sklearn infers it from the data, and a
    # single-class input would no longer match the two display names.
    binary_labels = [0, 1]

    print("confusion matrix (cutoff={})".format(thresh))
    print(classification_report(y_true, y_pred, labels=binary_labels, target_names=tick_labels))
    conf_mtx = confusion_matrix(y_true, y_pred, labels=binary_labels)
    sns.heatmap(conf_mtx, xticklabels=tick_labels, yticklabels=tick_labels, annot=True, fmt='d')
    plt.title('Confusion Matrix')
    plt.ylabel('True Class')
    plt.xlabel('Predicted Class')
    if is_single_fig:
        plt.show()
 
def prob_barplot(y_score, bins=np.arange(0.0, 1.11, 0.1), right=False, filename=None, figsize=(10,4), is_single_fig=False):
    """
    Bin predicted scores and plot the percentage of samples per bin as a bar chart.

    y_score       -- 1-D array-like of predicted scores
    bins          -- bin edges passed to pd.cut
    right         -- whether bins include their right edge (passed to pd.cut)
    filename      -- if given, the figure is saved as '<filename>.png'
    figsize       -- figure size passed to the pandas bar plot
    is_single_fig -- when truthy, plt.show() is called after drawing

    The per-bin percentages are also printed.
    """
    # Work on a bare array so a pandas Series input is binned positionally.
    binned = pd.cut(np.asarray(y_score), bins, right=right)
    # Series.value_counts() orders by frequency by default; a histogram needs
    # the bars in bin order, so count without sorting and then order by bin.
    counts = pd.Series(binned).value_counts(sort=False).sort_index()
    # Denominator is the total sample count, including scores outside every bin.
    percents = 100. * counts / len(binned)
    percents.plot.bar(rot=0, figsize=figsize)
    plt.title('Histogram of score')
    print(percents)
    if filename is not None:
        plt.savefig('{}.png'.format(filename))
    if is_single_fig:
        plt.show()
 
def show_evals(y_true, y_score, thresh=0.5):
    """
    One-call evaluation: ROC curve, precision-recall curve and confusion
    matrix side by side in a single 1x3 figure.

    y_true  -- ground-truth binary labels
    y_score -- predicted scores
    thresh  -- cutoff forwarded to plot_conf_mtx
    """
    plt.figure(figsize=(14, 4))

    # One drawing callable per panel, left to right.
    panels = (
        lambda: plot_roc_curve(y_true, y_score),
        lambda: plot_pr_curve(y_true, y_score),
        lambda: plot_conf_mtx(y_true, y_score, thresh),
    )
    for position, draw_panel in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        draw_panel()

    plt.show()


In [None]:
# Local directory the downloaded model archive is extracted into.
MODEL_DIR = 'model_pkg'

# S3 location of the model archive -- replace with your own path.
S3_MODEL_PATH = '[YOUR S3 MODEL PATH]'

# CSV file holding the test set -- replace with your own file.
TEST_FILE = 'test.csv'

# Name of the target column in the test CSV.
LABEL_COLUMN = 'label'

# Probability cutoff used to turn scores into class predictions.
THRESH = 0.5

In [None]:
!aws s3 cp {S3_MODEL_PATH} .
!rm -rf {MODEL_DIR}
!mkdir {MODEL_DIR}
!tar -xzvf model.tar.gz -C {MODEL_DIR}

In [None]:
# Read the test CSV, then separate the target column from the features.
# NOTE(review): some AutoGluon releases expect the path as the first positional
# argument of TabularDataset rather than `file_path=` -- confirm against the installed version.
test_data = TabularDataset(file_path=TEST_FILE)
test_data_nolab = test_data.drop(columns=[LABEL_COLUMN])  # features only; the label is withheld from the model
y_test = test_data[LABEL_COLUMN]  # ground-truth values to predict

In [None]:
# Load the already-fitted predictor from the extracted model directory.
# Calling the constructor directly would create a new predictor rather than
# restore the saved one.
predictor = TabularPredictor.load(MODEL_DIR)
# as_multiclass=False asks for a single column of positive-class probabilities,
# which is what the thresholding below and the plotting helpers expect.
y_prob = predictor.predict_proba(test_data_nolab, as_multiclass=False)
# Same `>=` rule as plot_conf_mtx, and integer 0/1 predictions so they are
# comparable with the 0/1 labels (assumes the label column is 0/1 encoded).
y_pred = (y_prob >= THRESH).astype(int)

In [None]:
# Request the detailed report explicitly: the confusion matrix is only part of
# the returned dict when detailed_report is enabled, and its default differs
# between AutoGluon releases.
perf = predictor.evaluate_predictions(
    y_true=y_test,
    y_pred=y_pred,
    auxiliary_metrics=True,
    detailed_report=True,
)

In [None]:
# Display the confusion-matrix entry of the evaluation results.
perf["confusion_matrix"]

In [None]:
# Call the predictor's fit summary and keep the returned value in `results`.
results = predictor.fit_summary()

In [None]:
# Pass the configured cutoff so the confusion matrix uses THRESH rather than
# the helper's built-in default.
show_evals(y_test, y_prob, THRESH)

In [None]:
%%time

# Feature importance
featimp = predictor.feature_importance(test_data)
fig, ax = plt.subplots(figsize=(12,5))
plot = sns.barplot(x=featimp.index, y=featimp.values)
ax.set_title('Feature Importance')
plot.set_xticklabels(plot.get_xticklabels(), rotation='vertical')

plt.show()