# Evaluate results.
This notebook visualises the monitoring violations produced in the previous steps.

In [None]:
import boto3
import sagemaker
import json
import os
from utils import get_aws_profile_name, get_aws_iam_role, get_trial_name
import generate_altered
from generate_altered import TransformsEnum
import pandas as pd

LOCAL_EXECUTION = True

# Create the boto3 session and resolve the IAM role ARN. Locally, a named AWS
# profile is used and the role is looked up in IAM by name; otherwise a default
# boto3 session and the SageMaker execution role are used.
if not LOCAL_EXECUTION:
    sess = boto3.Session()
    sm = sess.client("sagemaker")
    role = sagemaker.get_execution_role()
else:
    sess = boto3.Session(profile_name=get_aws_profile_name())
    sm = sess.client("sagemaker")
    iam = sess.client('iam')
    role = iam.get_role(RoleName=get_aws_iam_role())['Role']['Arn']

# SageMaker SDK session bound to the boto3 session above, plus bucket/prefix/region settings.
sagemaker_session = sagemaker.Session(boto_session=sess)
region = sess.region_name
bucket = sagemaker_session.default_bucket()
prefix = "model-monitor-bring-your-own-model/"

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
from botocore.exceptions import ClientError


### Helpers

In [None]:
def get_latest_trial_component(transform:TransformsEnum):
    """Return the SageMaker trial component recorded for ``transform``.

    The transform's ``__name__`` has its underscores turned into hyphens and is
    passed to ``get_trial_name`` to obtain the trial component name, which is
    then looked up with ``describe_trial_component``.

    Args:
        transform: a ``TransformsEnum`` transform (anything exposing ``__name__``).

    Returns:
        The ``describe_trial_component`` response dict.
    """
    hyphenated = "-".join(transform.__name__.split("_"))
    return sm.describe_trial_component(TrialComponentName=get_trial_name(hyphenated))

In [None]:
def plot_baseline_data_violations(violations):
    """Plot each feature's observed baseline-drift distance against its threshold.

    Only violations whose ``constraint_check_type`` is ``'baseline_drift_check'``
    are used. In each such violation's ``description`` the first number is read
    as the observed distance and the second as the threshold. Each bar is green
    up to the threshold and red for the part above it. A dashed line marks that
    feature's own threshold.

    Args:
        violations: list of violation dicts loaded from Model Monitor
            ``constraint_violations.json`` files.

    Returns:
        None. If there is no baseline-drift violation, a message is printed and
        nothing is plotted.
    """
    # Take the description from the drift violation itself rather than looking
    # it up by feature name: the same feature can also appear in a violation of
    # another check type, whose description would parse to unrelated numbers.
    drift_violations = [
        d for d in violations if d.get('constraint_check_type') == 'baseline_drift_check'
    ]
    if not drift_violations:
        print("No feature baseline violation found - exiting plot_baseline_data_violations")
        return

    # Matches plain and scientific-notation numbers ("0.5", "6.2e-05"). A minus
    # sign is kept only when it starts a whitespace-separated token, so
    # hyphenated words do not produce negative values.
    number_pattern = re.compile(r"(?:(?<!\S)-)?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
    metrics, observed, thresholds = [], [], []
    for violation in drift_violations:
        numbers = [float(v) for v in number_pattern.findall(violation['description'])]
        metrics.append(violation["feature_name"])
        observed.append(numbers[0])
        thresholds.append(numbers[1])

    x = np.arange(len(metrics))
    observed = np.asarray(observed)
    thresholds = np.asarray(thresholds)
    # Split every bar into the part up to the threshold and the excess above it.
    above_threshold = np.maximum(observed - thresholds, 0)
    below_threshold = np.minimum(observed, thresholds)

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')

    ax = fig.add_axes([0, 0, 1, 1])
    bar_width = 0.35
    ax.bar(x, below_threshold, bar_width, color="g")
    ax.bar(x, above_threshold, bar_width, color="r", bottom=below_threshold, label="observed exceeding threshold")
    # One dashed marker per bar, at that feature's own threshold.
    ax.hlines(thresholds, x - bar_width, x + bar_width, colors="k", linestyles="--")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title('Feature baseline drift was detected')
    # One stacked bar per metric, so the ticks sit at the bar centres.
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.legend(bbox_to_anchor=(1, 1))
    ax.text(x[0] - bar_width, thresholds[0] * 1.02, ' threshold')
    plt.show()


In [None]:
def plot_model_performance_violations(violations, metrics=("precision", "auc", "accuracy")):
    """Plot observed model-quality metrics against their violated thresholds.

    For each name in ``metrics``, the first violation whose ``metric_name``
    matches is used. Its ``description`` is parsed for numbers: the first is
    plotted as "observed" and the last as "baseline".

    Args:
        violations: list of violation dicts loaded from Model Monitor
            ``constraint_violations.json`` files.
        metrics: metric names to plot. A metric without a matching violation
            is skipped instead of raising ``IndexError``.

    Returns:
        None. If none of ``metrics`` has a violation, a message is printed and
        nothing is plotted.
    """
    # Plain and scientific-notation numbers; a leading minus only when it starts a token.
    number_pattern = re.compile(r"(?:(?<!\S)-)?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
    plotted, observed, baseline = [], [], []
    for metric in metrics:
        description = next(
            (d['description'] for d in violations if d.get("metric_name") == metric), None
        )
        if description is None:
            continue
        numbers = [float(v) for v in number_pattern.findall(description)]
        plotted.append(metric)
        observed.append(numbers[0])
        # NOTE(review): presumably descriptions look like
        # "Metric auc with 0.5 +/- 0.0 was LessThanThreshold '0.7'", so the
        # threshold is the last number. Taking the last one keeps working when
        # the "+/- std" part is missing or written in e-notation.
        baseline.append(numbers[-1])

    if not plotted:
        print("No model performance violation found - exiting plot_model_performance_violations")
        return

    x = np.arange(len(plotted))

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')

    ax = fig.add_axes([0, 0, 1, 1])
    bar_width = .25
    ax.bar(x, observed, color='b', width=bar_width, label="observed")
    ax.bar(x + bar_width, baseline, color='g', width=bar_width, label="baseline")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title('Model performance deterioration was detected')
    # Centre each tick between the observed/baseline pair.
    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels(plotted)
    ax.legend(bbox_to_anchor=(1, 1))

    plt.show()


In [None]:
def plot_label_drift():
    """Compare the mean ``credit_risk`` before and after label drift.

    Loads ``data/train.csv`` as the baseline and applies
    ``TransformsEnum.LABEL_DRIFT`` to ``generate_altered.load_data()`` for the
    observed data. Prints both ``credit_risk`` means (baseline first) and draws
    them, as percentages, as side-by-side bars. ``credit_risk`` is presumably a
    0/1 acceptance label — confirm against the dataset.
    """
    original_df = pd.read_csv("data/train.csv")
    altered_df = TransformsEnum.LABEL_DRIFT(generate_altered.load_data())
    baseline_rate = original_df.credit_risk.mean()
    observed_rate = altered_df.credit_risk.mean()
    print(baseline_rate)
    print(observed_rate)

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')
    x = np.arange(1)
    ax = fig.add_axes([0, 0, 1, 1])
    bar_width = .25
    ax.bar(x, observed_rate * 100, color='b', width=bar_width, label="observed")
    ax.bar(x + bar_width, baseline_rate * 100, color='g', width=bar_width, label="baseline")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title('Percentage of credit accepted')
    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels(["Percentage of credit accepted"])
    plt.xlim([-.5, 1])
    ax.legend(bbox_to_anchor=(1, 1))

    # Render explicitly, as the other plotting helpers do. Otherwise the figure
    # only appears under an inline backend at the end of the cell.
    plt.show()



In [None]:
def plot_bias_violations(violations, focus_metric=None):
    """Plot observed vs baseline values for every bias-drift violation.

    Only violations whose ``constraint_check_type`` is ``'bias_drift_check'`` are
    used. In each one's ``description`` the first number is plotted as
    "observed" and the second as "baseline".

    Args:
        violations: list of violation dicts loaded from Model Monitor
            ``constraint_violations.json`` files.
        focus_metric: optional bias metric name (e.g. ``"DPPL"``) to plot again
            on its own. If that metric has no bias-drift violation, a message is
            printed instead of raising ``ValueError``.

    Returns:
        None. If there is no bias-drift violation, a message is printed and
        nothing is plotted.
    """
    # for explanation of metrics see: https://pages.awscloud.com/rs/112-TZM-766/images/Fairness.Measures.for.Machine.Learning.in.Finance.pdf
    bias_violations = [d for d in violations if d.get('constraint_check_type') == 'bias_drift_check']
    if not bias_violations:
        print("No bias drift violation found - exiting plot_bias_violations")
        return

    # Bias metrics such as DPPL can be negative, so a leading minus is kept
    # (only when it starts a whitespace-separated token). Scientific notation
    # is parsed as a single number.
    number_pattern = re.compile(r"(?:(?<!\S)-)?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
    metrics, observed, baseline = [], [], []
    for violation in bias_violations:
        numbers = [float(v) for v in number_pattern.findall(violation['description'])]
        metrics.append(violation["metric_name"])
        observed.append(numbers[0])
        baseline.append(numbers[1])

    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    style = 'seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid'
    bar_width = .25

    x = np.arange(len(metrics))
    fig = plt.figure()
    plt.style.use(style)

    ax = fig.add_axes([0, 0, 1, 1])
    ax.bar(x, observed, color='b', width=bar_width, label="observed")
    ax.bar(x + bar_width, baseline, color='g', width=bar_width, label="baseline")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title('Bias drift check triggered for the following')
    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels(metrics)
    ax.legend(bbox_to_anchor=(1, 1))

    plt.show()

    if not focus_metric:
        return
    if focus_metric not in metrics:
        print(f"No bias drift violation found for metric {focus_metric}")
        return

    loc = metrics.index(focus_metric)
    fig = plt.figure()
    plt.style.use(style)
    x = np.arange(1)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.bar(x, observed[loc], color='b', width=bar_width, label="observed")
    ax.bar(x + bar_width, baseline[loc], color='g', width=bar_width, label="baseline")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title(f'Bias drift check triggered for the metric {focus_metric}')
    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels([focus_metric])
    plt.xlim([-.5, 1])
    ax.legend(bbox_to_anchor=(1, 1))

    plt.show()


In [None]:
import utils

def _plot_feature_importance(features, observed, baseline, title):
    """Draw observed vs baseline feature importances as grouped horizontal bars."""
    y = np.arange(len(features))

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')

    ax = fig.add_axes([0, 0, 1, 1])
    bar_height = .25
    ax.barh(y, observed, color='b', height=bar_height, label="observed")
    ax.barh(y + bar_height, baseline, color='g', height=bar_height, label="baseline")

    ax.set_ylabel('Features')
    ax.set_xlabel('Feature importance')
    ax.set_title(title)
    ax.set_yticks(y + bar_height / 2)
    ax.set_yticklabels(features)
    ax.legend(bbox_to_anchor=(1, 1))

    plt.show()


def _read_global_shap_values(uri):
    """Read a Clarify analysis JSON from S3 and return its label0 kernel-SHAP global values dict."""
    analysis = json.loads(sagemaker.s3.S3Downloader().read_file(uri, sagemaker_session=sagemaker_session))
    return analysis["explanations"]["kernel_shap"]["label0"]["global_shap_values"]


def plot_model_explainability_violations(violations, job_output_uri, top_n=5):
    """Report the SHAP-drift NDCG and compare observed vs baseline feature importances.

    Prints the first number in the ``shap`` violation's description as the NDCG.
    Loads global SHAP values from the job's analysis report and from the
    baseline analysis. Draws two charts: all features in report order, then the
    ``top_n`` features with the largest absolute observed importance.

    Args:
        violations: list of violation dicts loaded from Model Monitor
            ``constraint_violations.json`` files.
        job_output_uri: S3 URI of the explainability ``constraint_violations``
            report. The analysis report URI is derived from it.
        top_n: number of features shown in the second chart.

    Returns:
        None. If there is no ``shap`` violation, a message is printed and
        nothing is plotted.
    """
    shap_description = next(
        (d['description'] for d in violations if d.get("metric_name") == "shap"), None
    )
    if shap_description is None:
        print("No shap violation found - exiting plot_model_explainability_violations")
        return

    number_pattern = re.compile(r"(?:(?<!\S)-)?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
    ndcg = float(number_pattern.findall(shap_description)[0])
    print(f"NDCG (Normalized Discounted Cumulative Gain) computed was {ndcg}")

    # Derive the analysis report URI from the violations report URI.
    analysis_uri = job_output_uri.replace("constraint_violations", "analysis")
    observed_shap = _read_global_shap_values(analysis_uri)
    baseline_shap = _read_global_shap_values(utils.get_baseline_uri("model-explainability-analysis"))

    features = list(observed_shap)
    observed = [observed_shap[f] for f in features]
    # A feature absent from the baseline report is plotted as NaN (no bar)
    # instead of None, which matplotlib cannot draw.
    baseline = [baseline_shap.get(f, float("nan")) for f in features]

    _plot_feature_importance(
        features, observed, baseline,
        'Feature importance differences in observed and baseline datasets',
    )

    # Second chart: the top_n features ranked by absolute observed importance.
    top = sorted(range(len(features)), key=lambda i: abs(observed[i]), reverse=True)[:top_n]
    _plot_feature_importance(
        [features[i] for i in top],
        [observed[i] for i in top],
        [baseline[i] for i in top],
        f'Top {top_n} feature importance differences in observed and baseline datasets',
    )


In [None]:
def plot_feature_drift(transform=TransformsEnum.FEATURE_DRIFT, feature="duration"):
    """Overlay histograms of ``feature`` in the training data and in the drifted data.

    Loads ``data/train.csv`` as the original data and applies ``transform`` to
    ``generate_altered.load_data()`` for the observed data. Prints the
    ``credit_risk`` mean of both (original first).

    Args:
        transform: callable applied to ``generate_altered.load_data()`` to
            produce the observed DataFrame.
        feature: numeric column to compare; defaults to ``"duration"``.
    """
    original_df = pd.read_csv("data/train.csv")
    altered_df = transform(generate_altered.load_data())
    print(original_df.credit_risk.mean())
    print(altered_df.credit_risk.mean())

    original_values = original_df[feature].dropna().to_numpy()
    observed_values = altered_df[feature].dropna().to_numpy()
    # Use one set of bin edges for both histograms. With per-series binning
    # (the matplotlib default) the overlaid bars have different edges and
    # cannot be compared directly.
    bins = np.histogram_bin_edges(np.concatenate([original_values, observed_values]), bins=10)

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')
    ax = fig.add_axes([0, 0, 1, 1])

    ax.hist(original_values, bins=bins, alpha=0.5, label="original")
    ax.hist(observed_values, bins=bins, alpha=0.5, label="observed")

    ax.set_xlabel(f'Feature: {feature.capitalize()}')
    ax.set_ylabel('frequency')
    ax.set_title(f'Distribution of {feature}')

    plt.legend()
    plt.show()



### Concept drift

In [None]:
# Violation report URIs are read from the trial component that the earlier
# steps logged for this scenario.
component = get_latest_trial_component(TransformsEnum.CONCEPT_DRIFT)

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    component['Parameters']['model-explainability-violoations']['StringValue']
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]

plot_model_performance_violations(violations)
plot_bias_violations(violations)

In this scenario, concept drift degraded the model's performance: precision, along with the other monitored metrics, dropped considerably.

### Label Drift

In [None]:
component = get_latest_trial_component(TransformsEnum.LABEL_DRIFT)

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    component['Parameters']['model-explainability-violoations']['StringValue']
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]

plot_model_performance_violations(violations)
plot_label_drift()

### Bias drift

In [None]:
component = get_latest_trial_component(TransformsEnum.BIAS_DRIFT)

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    component['Parameters']['model-explainability-violoations']['StringValue']
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]

plot_bias_violations(violations, focus_metric="DPPL")

In [None]:
def plot_bias_dppl_expanded():
    """Compare mean ``credit_risk`` by ``foreign_worker`` group before and after bias drift.

    Loads ``data/train.csv`` as the baseline and applies
    ``TransformsEnum.BIAS_DRIFT`` to ``generate_altered.load_data()`` for the
    observed data. Prints the overall ``credit_risk`` mean of both (baseline
    first). For each DataFrame, plots the ``credit_risk`` mean (as a percentage)
    for rows with ``foreign_worker == 1``, for all other rows, and overall, as
    grouped bars.
    """
    original_df = pd.read_csv("data/train.csv")
    altered_df = TransformsEnum.BIAS_DRIFT(generate_altered.load_data())
    print(original_df.credit_risk.mean())
    print(altered_df.credit_risk.mean())

    def _acceptance_percentages(df):
        # [foreign_worker == 1, foreign_worker != 1, all rows], each as a percentage.
        is_foreign = df.foreign_worker == 1
        return [
            df.loc[is_foreign].credit_risk.mean() * 100,
            df.loc[~is_foreign].credit_risk.mean() * 100,
            df.credit_risk.mean() * 100,
        ]

    data_observed = _acceptance_percentages(altered_df)
    data_baseline = _acceptance_percentages(original_df)

    fig = plt.figure()
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8.
    plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available else 'seaborn-v0_8-darkgrid')
    x = np.arange(3)
    ax = fig.add_axes([0, 0, 2, 1])
    bar_width = .25
    ax.bar(x, data_observed, color='b', width=bar_width, label="observed")
    ax.bar(x + bar_width, data_baseline, color='g', width=bar_width, label="baseline")

    ax.set_xlabel('Metric')
    ax.set_ylabel('Metric value')
    ax.set_title('Percentage of credit accepted')
    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels(["Percentage of Accepted Foreign Workers", "Percentage of Accepted non-Foreign Workers", "Percentage of All Accepted Workers"])
    ax.legend(bbox_to_anchor=(1, 1))

    # Render explicitly, as the other plotting helpers do.
    plt.show()


plot_bias_dppl_expanded()

### Feature Drift

In [None]:
component = get_latest_trial_component(TransformsEnum.FEATURE_DRIFT)

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    component['Parameters']['model-explainability-violoations']['StringValue']
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]

plot_baseline_data_violations(violations)
plot_feature_drift()

### Feature Drift - Systematic issue

In [None]:
component = get_latest_trial_component(TransformsEnum.FEATURE_DRIFT_SYSTEMATIC)

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    component['Parameters']['model-explainability-violoations']['StringValue']
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]

plot_baseline_data_violations(violations)
plot_feature_drift(TransformsEnum.FEATURE_DRIFT_SYSTEMATIC)


### Explainability Drift

In [None]:
component = get_latest_trial_component(TransformsEnum.EXPLAINABILITY_DRIFT)
explainability_violations_uri = component['Parameters']['model-explainability-violoations']['StringValue']

# The "violoations" spelling must match the parameter names logged earlier.
violations_uris = [
    component['Parameters']['data-quality-violoations']['StringValue'],
    component['Parameters']['model-quality-violoations']['StringValue'],
    component['Parameters']['model-bias-violoations']['StringValue'],
    explainability_violations_uri
    ]

violations = []
for violations_uri in violations_uris:
    try:
        viol = json.loads(sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session))
    except ClientError as ex:
        # A missing report (NoSuchKey) is tolerated and reported. Any other S3
        # error (e.g. access denied) is re-raised rather than silently swallowed.
        if ex.response['Error']['Code'] != 'NoSuchKey':
            raise
        print(f"No violation file {violations_uri} found ")
        continue
    violations += viol["violations"]


plot_model_explainability_violations(violations, explainability_violations_uri)

### End