{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model quality monitoring: analysis against an existing baseline\n",
    "\n",
    "Runs the SageMaker `model-monitor` container as a processing job with `analysis_type=MODEL_QUALITY`, passing a dataset together with the baseline constraints and statistics resolved through `utils.get_baseline_uri`, then reads the resulting `constraint_violations.json` from S3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import boto3\n",
    "import sagemaker\n",
    "from botocore.exceptions import ClientError\n",
    "from sagemaker.model_monitor.dataset_format import DatasetFormat\n",
    "from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor\n",
    "\n",
    "from utils import (\n",
    "    get_aws_iam_role,\n",
    "    get_aws_profile_name,\n",
    "    get_baseline_uri,\n",
    "    get_dataset_uri,\n",
    "    save_baseline,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Session and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: build the session from a named AWS profile and look the role ARN up through IAM.\n",
    "# False: build a default session and use the SageMaker execution role.\n",
    "LOCAL_EXECUTION = True\n",
    "\n",
    "if LOCAL_EXECUTION:\n",
    "    sess = boto3.Session(profile_name=get_aws_profile_name())\n",
    "    iam = sess.client(\"iam\")\n",
    "    role = iam.get_role(RoleName=get_aws_iam_role())[\"Role\"][\"Arn\"]\n",
    "else:\n",
    "    sess = boto3.Session()\n",
    "    role = sagemaker.get_execution_role()\n",
    "\n",
    "sm = sess.client(\"sagemaker\")\n",
    "sagemaker_session = sagemaker.Session(boto_session=sess)\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = \"model-monitor-bring-your-own-model/\"\n",
    "region = sess.region_name\n",
    "\n",
    "# Timestamp suffix so that each run writes its results under a distinct S3 prefix.\n",
    "run_timestamp = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
    "output_uri = f\"s3://{bucket}/{prefix}results-model-quality-{run_timestamp}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monitoring container and environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_monitor_container_uri = sagemaker.image_uris.retrieve(\n",
    "    framework=\"model-monitor\",\n",
    "    region=region,\n",
    "    version=\"latest\",\n",
    ")\n",
    "\n",
    "# Format of the input dataset; serialized to JSON for the container environment below.\n",
    "dataset_format = DatasetFormat.csv()\n",
    "\n",
    "env = {\n",
    "    \"dataset_format\": json.dumps(dataset_format),\n",
    "    \"dataset_source\": \"/opt/ml/processing/input/baseline_dataset_input\",\n",
    "    \"output_path\": \"/opt/ml/processing/output\",\n",
    "    \"publish_cloudwatch_metrics\": \"Disabled\",\n",
    "    \"analysis_type\": \"MODEL_QUALITY\",\n",
    "    \"problem_type\": \"BinaryClassification\",\n",
    "    \"inference_attribute\": \"prediction\",  # The column in the dataset that contains predictions.\n",
    "    \"probability_attribute\": \"prediction_probability\",  # The column in the dataset that contains probabilities.\n",
    "    \"ground_truth_attribute\": \"credit_risk\",\n",
    "    \"baseline_constraints\": \"/opt/ml/processing/baseline/constraints/constraints.json\",\n",
    "    \"baseline_statistics\": \"/opt/ml/processing/baseline/stats/statistics.json\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the analysis job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_quality_monitor_analyzer = Processor(\n",
    "    image_uri=model_monitor_container_uri,\n",
    "    role=role,\n",
    "    instance_count=1,\n",
    "    instance_type=\"ml.m5.xlarge\",\n",
    "    base_job_name=\"model-monitor-byom\",\n",
    "    sagemaker_session=sagemaker_session,\n",
    "    max_runtime_in_seconds=1800,\n",
    "    env=env,\n",
    ")\n",
    "\n",
    "# Input/output destinations mirror the container paths referenced in `env`.\n",
    "model_quality_monitor_analyzer.run(\n",
    "    inputs=[\n",
    "        ProcessingInput(\n",
    "            source=get_dataset_uri(\"model-quality-modified-data\"),\n",
    "            destination=\"/opt/ml/processing/input/baseline_dataset_input\",\n",
    "            input_name=\"baseline_dataset_input\",\n",
    "        ),\n",
    "        ProcessingInput(\n",
    "            source=get_baseline_uri(\"model-quality-constraints\"),\n",
    "            destination=\"/opt/ml/processing/baseline/constraints\",\n",
    "            input_name=\"constraints\",\n",
    "        ),\n",
    "        ProcessingInput(\n",
    "            source=get_baseline_uri(\"model-quality-statistics\"),\n",
    "            destination=\"/opt/ml/processing/baseline/stats\",\n",
    "            input_name=\"baseline\",\n",
    "        ),\n",
    "    ],\n",
    "    outputs=[\n",
    "        ProcessingOutput(\n",
    "            source=\"/opt/ml/processing/output\",\n",
    "            output_name=\"monitoring_output\",\n",
    "            destination=output_uri,\n",
    "        )\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save and inspect constraint violations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "violations_uri = (\n",
    "    model_quality_monitor_analyzer.latest_job.outputs[0].destination\n",
    "    + \"/constraint_violations.json\"\n",
    ")\n",
    "# NOTE(review): the key is spelled 'violoations'; left as-is because other code may\n",
    "# look the entry up by this exact name -- confirm before renaming.\n",
    "save_baseline(\"model-quality-violoations\", violations_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    violations = json.loads(\n",
    "        sagemaker.s3.S3Downloader().read_file(violations_uri, sagemaker_session=sagemaker_session)\n",
    "    )\n",
    "    print(violations)\n",
    "except ClientError as ex:\n",
    "    # Only a missing object is treated as \"no violations file\"; any other S3 error\n",
    "    # (permissions, throttling, ...) is re-raised instead of being silently dropped.\n",
    "    if ex.response[\"Error\"][\"Code\"] != \"NoSuchKey\":\n",
    "        raise\n",
    "    print(\"No violation file found\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('general')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "4852bb1f7cd44f51326f23dc402ead6dde438dc19e87d2e1ec37a0afdae1dc27"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}