{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import json\n",
    "import os\n",
    "from utils import get_aws_profile_name, get_aws_iam_role\n",
    "from datetime import datetime\n",
    "\n",
    "# True: run from a workstation with a named AWS profile and an IAM role looked up by name.\n",
    "# False: run inside SageMaker with the default boto3 session and the SageMaker execution role.\n",
    "LOCAL_EXECUTION = True\n",
    "\n",
    "sess = boto3.Session(profile_name=get_aws_profile_name()) if LOCAL_EXECUTION else boto3.Session()\n",
    "sm = sess.client(\"sagemaker\")\n",
    "\n",
    "if LOCAL_EXECUTION:\n",
    "    iam = sess.client('iam')\n",
    "    role = iam.get_role(RoleName=get_aws_iam_role())['Role']['Arn']\n",
    "else:\n",
    "    role = sagemaker.get_execution_role()\n",
    "\n",
    "sagemaker_session = sagemaker.Session(boto_session=sess)\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = \"model-monitor-bring-your-own-model/\"\n",
    "region = sess.region_name\n",
    "\n",
    "# Timestamp the results location (to the second) so separate runs write to separate S3 prefixes.\n",
    "run_timestamp = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
    "output_uri = f\"s3://{bucket}/{prefix}results-model-quality-{run_timestamp}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.processing import (\n",
    "    ProcessingInput,\n",
    "    ProcessingOutput,\n",
    "    Processor,\n",
    ")\n",
    "from sagemaker.model_monitor.dataset_format import DatasetFormat\n",
    "from utils import get_baseline_uri, get_dataset_uri\n",
    "\n",
    "# Model Monitor analyzer image for the session's region.\n",
    "model_monitor_container_uri = sagemaker.image_uris.retrieve(\n",
    "    framework=\"model-monitor\",\n",
    "    region=region,\n",
    "    version=\"latest\",\n",
    ")\n",
    "\n",
    "# The dataset is CSV; the container receives the format spec as a JSON string.\n",
    "dataset_format = DatasetFormat.csv()\n",
    "\n",
    "# Container environment for a MODEL_QUALITY analysis of a binary classifier.\n",
    "# The /opt/ml/processing/... paths must agree with the ProcessingInput /\n",
    "# ProcessingOutput destinations passed to the job's run() call.\n",
    "env = dict(\n",
    "    dataset_format=json.dumps(dataset_format),\n",
    "    dataset_source=\"/opt/ml/processing/input/baseline_dataset_input\",\n",
    "    output_path=\"/opt/ml/processing/output\",\n",
    "    publish_cloudwatch_metrics=\"Disabled\",\n",
    "    analysis_type=\"MODEL_QUALITY\",\n",
    "    problem_type=\"BinaryClassification\",\n",
    "    inference_attribute=\"prediction\",  # column containing the predictions\n",
    "    probability_attribute=\"prediction_probability\",  # column containing the probabilities\n",
    "    ground_truth_attribute=\"credit_risk\",  # column containing the ground-truth label\n",
    "    baseline_constraints=\"/opt/ml/processing/baseline/constraints/constraints.json\",\n",
    "    baseline_statistics=\"/opt/ml/processing/baseline/stats/statistics.json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.processing import (\n",
    "    ProcessingInput,\n",
    "    ProcessingOutput,\n",
    "    Processor,\n",
    ")\n",
    "\n",
    "# Generic processing job that runs the Model Monitor image with the `env` built earlier.\n",
    "# NOTE: the variable name is spelled \"mode_...\" (sic); later cells use this exact name.\n",
    "mode_quality_monitor_analyzer = Processor(\n",
    "    image_uri=model_monitor_container_uri,\n",
    "    role=role,\n",
    "    instance_count=1,\n",
    "    instance_type='ml.m5.xlarge',\n",
    "    base_job_name=\"model-monitor-byom\",\n",
    "    sagemaker_session=sagemaker_session,\n",
    "    max_runtime_in_seconds=1800,  # 30-minute cap on the job\n",
    "    env=env,\n",
    ")\n",
    "\n",
    "# Container-side destinations must match the paths referenced in `env`.\n",
    "monitoring_inputs = [\n",
    "    # Dataset to analyze.\n",
    "    ProcessingInput(\n",
    "        source=get_dataset_uri('model-quality-modified-data'),\n",
    "        destination=\"/opt/ml/processing/input/baseline_dataset_input\",\n",
    "        input_name=\"baseline_dataset_input\",\n",
    "    ),\n",
    "    # Baseline constraints.\n",
    "    ProcessingInput(\n",
    "        source=get_baseline_uri('model-quality-constraints'),\n",
    "        destination=\"/opt/ml/processing/baseline/constraints\",\n",
    "        input_name=\"constraints\",\n",
    "    ),\n",
    "    # Baseline statistics.\n",
    "    ProcessingInput(\n",
    "        source=get_baseline_uri('model-quality-statistics'),\n",
    "        destination=\"/opt/ml/processing/baseline/stats\",\n",
    "        input_name=\"baseline\",\n",
    "    ),\n",
    "]\n",
    "monitoring_outputs = [\n",
    "    ProcessingOutput(\n",
    "        source=\"/opt/ml/processing/output\",\n",
    "        output_name=\"monitoring_output\",\n",
    "        destination=output_uri,\n",
    "    ),\n",
    "]\n",
    "\n",
    "mode_quality_monitor_analyzer.run(inputs=monitoring_inputs, outputs=monitoring_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import save_baseline\n",
    "\n",
    "# Expected location of constraint_violations.json: the latest job's first output destination.\n",
    "monitoring_output_uri = mode_quality_monitor_analyzer.latest_job.outputs[0].destination\n",
    "violations_uri = monitoring_output_uri + '/constraint_violations.json'\n",
    "\n",
    "# NOTE(review): 'violoations' looks like a typo, but the key is kept verbatim in case other\n",
    "# notebooks look it up by this exact name -- confirm before renaming.\n",
    "save_baseline('model-quality-violoations', violations_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botocore.exceptions import ClientError\n",
    "from sagemaker.s3 import S3Downloader\n",
    "\n",
    "# Print the violations report if the job wrote one. A missing object (NoSuchKey) is\n",
    "# reported as such; any other S3 error is re-raised instead of being silently ignored.\n",
    "try:\n",
    "    violations_text = S3Downloader.read_file(violations_uri, sagemaker_session=sagemaker_session)\n",
    "except ClientError as ex:\n",
    "    if ex.response.get('Error', {}).get('Code') != 'NoSuchKey':\n",
    "        raise\n",
    "    print(\"No violation file found\")\n",
    "else:\n",
    "    print(json.loads(violations_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('general')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "4852bb1f7cd44f51326f23dc402ead6dde438dc19e87d2e1ec37a0afdae1dc27"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}