{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SageMaker Pipelines integration with Model Monitor and Clarify\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n",
"\n",
"\n",
"\n",
"---"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"This notebook showcases how Model Monitor and Clarify steps can be integrated with SageMaker Pipelines. This allows users to calculate\n",
"baselines for data quality and model quality checks by running the underlying Model Monitor and Clarify containers."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data/Model Quality, Bias, and Model Explainability Checks in SageMaker Pipelines\n",
"\n",
"This notebook introduces two new step types in SageMaker Pipelines -\n",
"* `QualityCheckStep`\n",
"* `ClarifyCheckStep`\n",
"\n",
"With these two steps, the pipeline is able to perform baseline calculations that are needed as a standard against which data/model quality issues can be detected (including bias and explainability).\n",
"\n",
"These steps leverage SageMaker pre-built containers:\n",
"\n",
"* `QualityCheckStep` (for Data/Model Quality): [sagemaker-model-monitor-analyzer](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-pre-built-container.html)\n",
"* `ClarifyCheckStep` (for Data/Model Bias and Model Explainability): [sagemaker-clarify-processing](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-configure-processing-jobs.html#clarify-processing-job-configure-container)\n",
"\n",
"The training dataset that you used to train the model is usually a good baseline dataset. The training dataset data schema and the inference dataset schema should exactly match (the number and order of the features). Note that the prediction/output columns are assumed to be the first columns in the training dataset. From the training dataset, you can ask SageMaker to suggest a set of baseline constraints and generate descriptive statistics to explore the data.\n",
"\n",
"These two new steps will always calculate new baselines using the dataset provided."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Drift Check Baselines in the Model Registry\n",
"\n",
"The `RegisterStep` has a new parameter called `drift_check_baselines`. This refers to the baseline files associated with the model. When deployed, these baseline files are used by Model Monitor for Model Quality/Data Quality checks. In addition, these baselines can be used in `QualityCheckStep` and `ClarifyCheckStep` to compare newly trained models against models that have already been registered in the Model Registry.\n",
"\n",
"### Step Properties\n",
"\n",
"The `QualityCheckStep` has the following properties -\n",
"\n",
"* `CalculatedBaselineStatistics` : The baseline statistics file calculated by the underlying Model Monitor container.\n",
"* `CalculatedBaselineConstraints` : The baseline constraints file calculated by the underlying Model Monitor container.\n",
"* `BaselineUsedForDriftCheckStatistics` and `BaselineUsedForDriftCheckConstraints` : These are the two properties used to set `drift_check_baseline` in the Model Registry. The values set in these properties vary depending on the parameters passed to the step. The different behaviors are described in the table below.\n",
"\n",
"The `ClarifyCheckStep` has the following properties -\n",
"\n",
"* `CalculatedBaselineConstraints` : The baseline constraints file calculated by the underlying Clarify container.\n",
"* `BaselineUsedForDriftCheckConstraints` : This property is used to set `drift_check_baseline` in the Model Registry. The values set in this property will vary depending on the parameters passed to the step. The different behaviors are described in the table below."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Notebook Overview\n",
"\n",
"This notebook should be run with `Python 3.9` using the SageMaker Studio `Python3 (Data Science)` kernel."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by installing the SageMaker Python SDK, boto, and AWS CLI."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"! pip install botocore boto3 awscli --upgrade\n",
"! pip install \"sagemaker>=2.99.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import boto3\n",
"import sagemaker\n",
"import sagemaker.session\n",
"\n",
"from sagemaker import utils\n",
"from sagemaker.estimator import Estimator\n",
"from sagemaker.inputs import TrainingInput, CreateModelInput, TransformInput\n",
"from sagemaker.model import Model\n",
"from sagemaker.transformer import Transformer\n",
"\n",
"from sagemaker.model_metrics import MetricsSource, ModelMetrics, FileSource\n",
"from sagemaker.drift_check_baselines import DriftCheckBaselines\n",
"from sagemaker.processing import (\n",
" ProcessingInput,\n",
" ProcessingOutput,\n",
" ScriptProcessor,\n",
")\n",
"from sagemaker.sklearn.processing import SKLearnProcessor\n",
"from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo\n",
"from sagemaker.workflow.condition_step import ConditionStep\n",
"from sagemaker.workflow.functions import JsonGet\n",
"\n",
"from sagemaker.workflow.parameters import (\n",
" ParameterBoolean,\n",
" ParameterInteger,\n",
" ParameterString,\n",
")\n",
"from sagemaker.workflow.pipeline import Pipeline\n",
"from sagemaker.workflow.properties import PropertyFile\n",
"from sagemaker.workflow.steps import (\n",
" ProcessingStep,\n",
" TrainingStep,\n",
" CreateModelStep,\n",
" TransformStep,\n",
")\n",
"from sagemaker.workflow.model_step import ModelStep\n",
"from sagemaker.workflow.pipeline_context import PipelineSession\n",
"\n",
"# Importing new steps and helper functions\n",
"\n",
"from sagemaker.workflow.check_job_config import CheckJobConfig\n",
"from sagemaker.workflow.clarify_check_step import (\n",
" DataBiasCheckConfig,\n",
" ClarifyCheckStep,\n",
" ModelBiasCheckConfig,\n",
" ModelPredictedLabelConfig,\n",
" ModelExplainabilityCheckConfig,\n",
" SHAPConfig,\n",
")\n",
"from sagemaker.workflow.quality_check_step import (\n",
" DataQualityCheckConfig,\n",
" ModelQualityCheckConfig,\n",
" QualityCheckStep,\n",
")\n",
"from sagemaker.workflow.execution_variables import ExecutionVariables\n",
"from sagemaker.workflow.functions import Join\n",
"from sagemaker.model_monitor import DatasetFormat, model_monitoring\n",
"from sagemaker.clarify import BiasConfig, DataConfig, ModelConfig"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the SageMaker Session"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"region = sagemaker.Session().boto_region_name\n",
"sm_client = boto3.client(\"sagemaker\")\n",
"boto_session = boto3.Session(region_name=region)\n",
"sagemaker_session = sagemaker.session.Session(boto_session=boto_session, sagemaker_client=sm_client)\n",
"pipeline_session = PipelineSession()\n",
"prefix = \"model-monitor-clarify-step-pipeline\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define variables and parameters needed for the Pipeline steps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"role = sagemaker.get_execution_role()\n",
"default_bucket = sagemaker_session.default_bucket()\n",
"base_job_prefix = \"model-monitor-clarify\"\n",
"model_package_group_name = \"model-monitor-clarify-group\"\n",
"pipeline_name = \"model-monitor-clarify-pipeline\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define pipeline parameters\n",
"\n",
"Both `QualityCheckStep` and `ClarifyCheckStep` use two boolean flags `skip_check` and `register_new_baseline` to control their behavior.\n",
"\n",
"* `skip_check` : This determines if a drift check is executed or not.\n",
"* `register_new_baseline` : This determines if the newly calculated baselines (in the step property `CalculatedBaselines`) should be set in the step property `BaselineUsedForDriftCheck`.\n",
"* `supplied_baseline_statistics` and `supplied_baseline_constraints` : If `skip_check` is set to False, baselines can be provided to this step through this parameter. If provided, the step will compare the newly calculated baselines (`CalculatedBaselines`) against those provided here instead of finding the latest baselines from the Model Registry. In the case of `ClarifyCheckStep`, only `supplied_baseline_constraints` is a valid parameter, for `QualityCheckStep`, both parameters are used.\n",
"* `model_package_group_name` : The step will use the `drift_check_baselines` from the latest approved model in the model package group for the drift check. If `supplied_baseline_*` is provided, this field will be ignored.\n",
"\n",
"The first time the pipeline is run, the `skip_check` value should be set to True using the pipeline execution parameters so that new baselines are registered and no drift check is executed."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Combining Pipeline parameters\n",
"\n",
"This table summarizes how the pipeline parameters work when combined.\n",
"\n",
"The parameter `drift_check_baselines` is used to supply baselines to the `RegisterStep` that will be used for all drift checks involving the model.\n",
"\n",
"Newly calculated baselines can be reference by the properties `CalculatedBaselineStatistics` and `CalculatedBaselineConstraints` on the `QualityCheckStep` and `CalculatedBaselineConstraints` on the `ClarifyCheckStep`.\n",
"\n",
"For example, `data_quality_check_step.properties.CalculatedBaselineStatistics` and `data_quality_check_step.properties.CalculatedBaselineConstraints`. This property refers to the baseline that is calculated when the data quality check step is executed."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"| `skip_check` / `register_new_baseline` | Does step do a drift check? | Value of step property `CalculatedBaseline` | Value of step property `BaselineUsedForDriftCheck` | Possible Circumstances for this parameter combination|\n",
"| -------------------------------------- | ---------------------------------------------------------|------------------------------------------------------------ |------------------------------------------------- | -----------------------------------------------------|\n",
"| F / F | Drift Check executed against existing baselines. | New baselines calculated by step execution | Baseline from latest approved model in Model Registry or baseline supplied as step parameter | Regular re-training with checks enabled to get a new model version, but carry over previous baselines as DriftCheckBaselines in Registry for new model version. |\n",
"| F / T | Drift Check executed against existing baselines. | New baselines calculated by step execution | Newly calculated baseline by step execution (value of property `CalculatedBaseline`) | Regular re-training with checks enabled to get a new model version, but refresh DriftCheckBaselines in Registry with newly calculated baselines for the new model version. |\n",
"| T / F | No Drift Check. | New baselines calculated by step execution | Baseline from latest approved model in Model Registry or baseline supplied as step parameter | Violation detected by the model monitor on endpoint for a particular type of check and the pipeline is triggered for retraining a new model. Skip the check against previous baselines, but carry over previous baselines as DriftCheckBaselines in Registry for new model version. |\n",
"| T / T | No Drift Check. | New baselines calculated by step execution | Newly calculated baseline by step execution (value of property `CalculatedBaseline`) | a. Initial run of the pipeline, building the first model version and generate initial baselines.
b. Violation detected by the model monitor on endpoint for a particular type of check and the pipeline is triggered for retraining a new model. Skip the check against previous baselines and refresh DriftCheckBaselines with newly calculated baselines in Registry directly. |"
]
},
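{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration of the table above, once the pipeline has been defined later in this notebook, a regular re-training run (the F / F row) could be started roughly as sketched below. The S3 URIs are hypothetical placeholders; if the `supplied_*` parameters are left empty, the step instead pulls the drift check baselines from the latest approved model in the Model Registry. The other check steps have analogous parameters.\n",
"\n",
"```python\n",
"# Sketch only: a re-training run with drift checks enabled (skip_check=False)\n",
"# and the previous baselines carried over (register_new_baseline=False).\n",
"execution = pipeline.start(\n",
"    parameters=dict(\n",
"        SkipDataQualityCheck=False,\n",
"        RegisterNewDataQualityBaseline=False,\n",
"        # Hypothetical S3 URIs of baselines produced by an earlier run\n",
"        DataQualitySuppliedStatistics=\"s3://<bucket>/<prefix>/statistics.json\",\n",
"        DataQualitySuppliedConstraints=\"s3://<bucket>/<prefix>/constraints.json\",\n",
"    )\n",
")\n",
"```"
]
},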
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"processing_instance_count = ParameterInteger(name=\"ProcessingInstanceCount\", default_value=1)\n",
"training_instance_type = ParameterString(name=\"TrainingInstanceType\", default_value=\"ml.m5.xlarge\")\n",
"model_approval_status = ParameterString(\n",
" name=\"ModelApprovalStatus\", default_value=\"PendingManualApproval\"\n",
")\n",
"# The dataset used here is the open source Abalone dataset that can be found\n",
"# here - https://archive.ics.uci.edu/ml/datasets/abalone\n",
"input_data = ParameterString(\n",
" name=\"InputDataUrl\",\n",
" default_value=f\"s3://sagemaker-example-files-prod-{region}/datasets/tabular/uci_abalone/abalone.csv\",\n",
")\n",
"\n",
"# for data quality check step\n",
"skip_check_data_quality = ParameterBoolean(name=\"SkipDataQualityCheck\", default_value=False)\n",
"register_new_baseline_data_quality = ParameterBoolean(\n",
" name=\"RegisterNewDataQualityBaseline\", default_value=False\n",
")\n",
"supplied_baseline_statistics_data_quality = ParameterString(\n",
" name=\"DataQualitySuppliedStatistics\", default_value=\"\"\n",
")\n",
"supplied_baseline_constraints_data_quality = ParameterString(\n",
" name=\"DataQualitySuppliedConstraints\", default_value=\"\"\n",
")\n",
"\n",
"# for data bias check step\n",
"skip_check_data_bias = ParameterBoolean(name=\"SkipDataBiasCheck\", default_value=False)\n",
"register_new_baseline_data_bias = ParameterBoolean(\n",
" name=\"RegisterNewDataBiasBaseline\", default_value=False\n",
")\n",
"supplied_baseline_constraints_data_bias = ParameterString(\n",
" name=\"DataBiasSuppliedBaselineConstraints\", default_value=\"\"\n",
")\n",
"\n",
"# for model quality check step\n",
"skip_check_model_quality = ParameterBoolean(name=\"SkipModelQualityCheck\", default_value=False)\n",
"register_new_baseline_model_quality = ParameterBoolean(\n",
" name=\"RegisterNewModelQualityBaseline\", default_value=False\n",
")\n",
"supplied_baseline_statistics_model_quality = ParameterString(\n",
" name=\"ModelQualitySuppliedStatistics\", default_value=\"\"\n",
")\n",
"supplied_baseline_constraints_model_quality = ParameterString(\n",
" name=\"ModelQualitySuppliedConstraints\", default_value=\"\"\n",
")\n",
"\n",
"# for model bias check step\n",
"skip_check_model_bias = ParameterBoolean(name=\"SkipModelBiasCheck\", default_value=False)\n",
"register_new_baseline_model_bias = ParameterBoolean(\n",
" name=\"RegisterNewModelBiasBaseline\", default_value=False\n",
")\n",
"supplied_baseline_constraints_model_bias = ParameterString(\n",
" name=\"ModelBiasSuppliedBaselineConstraints\", default_value=\"\"\n",
")\n",
"\n",
"# for model explainability check step\n",
"skip_check_model_explainability = ParameterBoolean(\n",
" name=\"SkipModelExplainabilityCheck\", default_value=False\n",
")\n",
"register_new_baseline_model_explainability = ParameterBoolean(\n",
" name=\"RegisterNewModelExplainabilityBaseline\", default_value=False\n",
")\n",
"supplied_baseline_constraints_model_explainability = ParameterString(\n",
" name=\"ModelExplainabilitySuppliedBaselineConstraints\", default_value=\"\"\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Processing step for feature engineering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"!mkdir -p code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%%writefile code/preprocess.py\n",
"\n",
"\"\"\"Feature engineers the abalone dataset.\"\"\"\n",
"import argparse\n",
"import logging\n",
"import os\n",
"import pathlib\n",
"import requests\n",
"import tempfile\n",
"\n",
"import boto3\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.addHandler(logging.StreamHandler())\n",
"\n",
"\n",
"# Since we get a headerless CSV file we specify the column names here.\n",
"feature_columns_names = [\n",
" \"sex\",\n",
" \"length\",\n",
" \"diameter\",\n",
" \"height\",\n",
" \"whole_weight\",\n",
" \"shucked_weight\",\n",
" \"viscera_weight\",\n",
" \"shell_weight\",\n",
"]\n",
"label_column = \"rings\"\n",
"\n",
"feature_columns_dtype = {\n",
" \"sex\": str,\n",
" \"length\": np.float64,\n",
" \"diameter\": np.float64,\n",
" \"height\": np.float64,\n",
" \"whole_weight\": np.float64,\n",
" \"shucked_weight\": np.float64,\n",
" \"viscera_weight\": np.float64,\n",
" \"shell_weight\": np.float64,\n",
"}\n",
"label_column_dtype = {\"rings\": np.float64}\n",
"\n",
"\n",
"def merge_two_dicts(x, y):\n",
" \"\"\"Merges two dicts, returning a new copy.\"\"\"\n",
" z = x.copy()\n",
" z.update(y)\n",
" return z\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" logger.debug(\"Starting preprocessing.\")\n",
" parser = argparse.ArgumentParser()\n",
" parser.add_argument(\"--input-data\", type=str, required=True)\n",
" args = parser.parse_args()\n",
"\n",
" base_dir = \"/opt/ml/processing\"\n",
" pathlib.Path(f\"{base_dir}/data\").mkdir(parents=True, exist_ok=True)\n",
" input_data = args.input_data\n",
" bucket = input_data.split(\"/\")[2]\n",
" key = \"/\".join(input_data.split(\"/\")[3:])\n",
"\n",
" logger.info(\"Downloading data from bucket: %s, key: %s\", bucket, key)\n",
" fn = f\"{base_dir}/data/abalone-dataset.csv\"\n",
" s3 = boto3.resource(\"s3\")\n",
" s3.Bucket(bucket).download_file(key, fn)\n",
"\n",
" logger.debug(\"Reading downloaded data.\")\n",
" df = pd.read_csv(\n",
" fn,\n",
" header=None,\n",
" names=feature_columns_names + [label_column],\n",
" dtype=merge_two_dicts(feature_columns_dtype, label_column_dtype),\n",
" )\n",
" os.unlink(fn)\n",
"\n",
" logger.debug(\"Defining transformers.\")\n",
" numeric_features = list(feature_columns_names)\n",
" numeric_features.remove(\"sex\")\n",
" numeric_transformer = Pipeline(\n",
" steps=[\n",
" (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
" (\"scaler\", StandardScaler()),\n",
" ]\n",
" )\n",
"\n",
" categorical_features = [\"sex\"]\n",
" categorical_transformer = Pipeline(\n",
" steps=[\n",
" (\"imputer\", SimpleImputer(strategy=\"constant\", fill_value=\"missing\")),\n",
" (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\")),\n",
" ]\n",
" )\n",
"\n",
" preprocess = ColumnTransformer(\n",
" transformers=[\n",
" (\"num\", numeric_transformer, numeric_features),\n",
" (\"cat\", categorical_transformer, categorical_features),\n",
" ]\n",
" )\n",
"\n",
" logger.info(\"Applying transforms.\")\n",
" y = df.pop(\"rings\")\n",
" X_pre = preprocess.fit_transform(df)\n",
" y_pre = y.to_numpy().reshape(len(y), 1)\n",
"\n",
" X = np.concatenate((y_pre, X_pre), axis=1)\n",
"\n",
" logger.info(\"Splitting %d rows of data into train, validation, test datasets.\", len(X))\n",
" np.random.shuffle(X)\n",
" train, validation, test = np.split(X, [int(0.7 * len(X)), int(0.85 * len(X))])\n",
"\n",
" logger.info(\"Writing out datasets to %s.\", base_dir)\n",
" pd.DataFrame(train).to_csv(f\"{base_dir}/train/train.csv\", header=False, index=False)\n",
" pd.DataFrame(validation).to_csv(\n",
" f\"{base_dir}/validation/validation.csv\", header=False, index=False\n",
" )\n",
" pd.DataFrame(test).to_csv(f\"{base_dir}/test/test.csv\", header=False, index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"sklearn_processor = SKLearnProcessor(\n",
" framework_version=\"0.23-1\",\n",
" instance_type=\"ml.m5.xlarge\",\n",
" instance_count=processing_instance_count,\n",
" base_job_name=f\"{base_job_prefix}/sklearn-abalone-preprocess\",\n",
" sagemaker_session=pipeline_session,\n",
" role=role,\n",
")\n",
"processor_args = sklearn_processor.run(\n",
" outputs=[\n",
" ProcessingOutput(output_name=\"train\", source=\"/opt/ml/processing/train\"),\n",
" ProcessingOutput(output_name=\"validation\", source=\"/opt/ml/processing/validation\"),\n",
" ProcessingOutput(output_name=\"test\", source=\"/opt/ml/processing/test\"),\n",
" ],\n",
" code=\"code/preprocess.py\",\n",
" arguments=[\"--input-data\", input_data],\n",
")\n",
"step_process = ProcessingStep(name=\"PreprocessAbaloneData\", step_args=processor_args)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating the Data Quality\n",
"\n",
"`CheckJobConfig` is a helper function that's used to define the job configurations used by the `QualityCheckStep`. By separating the job configuration from the step parameters, the same `CheckJobConfig` can be used across multiple steps for quality checks.\n",
"\n",
"The `DataQualityCheckConfig` is used to define the Quality Check job by specifying the dataset used to calculate the baseline, in this case, the training dataset from the data processing step, the dataset format, in this case, a csv file with no headers, and the output path for the results of the data quality check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"check_job_config = CheckJobConfig(\n",
" role=role,\n",
" instance_count=1,\n",
" instance_type=\"ml.c5.xlarge\",\n",
" volume_size_in_gb=120,\n",
" sagemaker_session=sagemaker_session,\n",
")\n",
"\n",
"data_quality_check_config = DataQualityCheckConfig(\n",
" baseline_dataset=step_process.properties.ProcessingOutputConfig.Outputs[\"train\"].S3Output.S3Uri,\n",
" dataset_format=DatasetFormat.csv(header=False, output_columns_position=\"START\"),\n",
" output_s3_uri=Join(\n",
" on=\"/\",\n",
" values=[\n",
" \"s3:/\",\n",
" default_bucket,\n",
" base_job_prefix,\n",
" ExecutionVariables.PIPELINE_EXECUTION_ID,\n",
" \"dataqualitycheckstep\",\n",
" ],\n",
" ),\n",
")\n",
"\n",
"data_quality_check_step = QualityCheckStep(\n",
" name=\"DataQualityCheckStep\",\n",
" skip_check=skip_check_data_quality,\n",
" register_new_baseline=register_new_baseline_data_quality,\n",
" quality_check_config=data_quality_check_config,\n",
" check_job_config=check_job_config,\n",
" supplied_baseline_statistics=supplied_baseline_statistics_data_quality,\n",
" supplied_baseline_constraints=supplied_baseline_constraints_data_quality,\n",
" model_package_group_name=model_package_group_name,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating the Data Bias\n",
"\n",
"The job configuration from the previous step is used here and the `DataConfig` class is used to define how the `ClarifyCheckStep` should compute the data bias. The training dataset is used again for the bias evaluation, the column representing the label is specified through the `label` parameter, and a `BiasConfig` is provided.\n",
"\n",
"In the `BiasConfig`, we specify a facet name (the column that is the focal point of the bias calculation), the value of the facet that determines the range of values it can hold, and the threshold value for the label.\n",
"\n",
"More details on `BiasConfig` can be found [here](https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.clarify.BiasConfig)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"data_bias_analysis_cfg_output_path = (\n",
" f\"s3://{default_bucket}/{base_job_prefix}/databiascheckstep/analysis_cfg\"\n",
")\n",
"\n",
"data_bias_data_config = DataConfig(\n",
" s3_data_input_path=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"train\"\n",
" ].S3Output.S3Uri,\n",
" s3_output_path=Join(\n",
" on=\"/\",\n",
" values=[\n",
" \"s3:/\",\n",
" default_bucket,\n",
" base_job_prefix,\n",
" ExecutionVariables.PIPELINE_EXECUTION_ID,\n",
" \"databiascheckstep\",\n",
" ],\n",
" ),\n",
" label=0,\n",
" dataset_type=\"text/csv\",\n",
" s3_analysis_config_output_path=data_bias_analysis_cfg_output_path,\n",
")\n",
"\n",
"\n",
"data_bias_config = BiasConfig(\n",
" label_values_or_threshold=[15.0], facet_name=[8], facet_values_or_threshold=[[0.5]]\n",
")\n",
"\n",
"data_bias_check_config = DataBiasCheckConfig(\n",
" data_config=data_bias_data_config,\n",
" data_bias_config=data_bias_config,\n",
")\n",
"\n",
"data_bias_check_step = ClarifyCheckStep(\n",
" name=\"DataBiasCheckStep\",\n",
" clarify_check_config=data_bias_check_config,\n",
" check_job_config=check_job_config,\n",
" skip_check=skip_check_data_bias,\n",
" register_new_baseline=register_new_baseline_data_bias,\n",
" supplied_baseline_constraints=supplied_baseline_constraints_data_bias,\n",
" model_package_group_name=model_package_group_name,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train an XGBoost Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_path = f\"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/AbaloneTrain\"\n",
"image_uri = sagemaker.image_uris.retrieve(\n",
" framework=\"xgboost\",\n",
" region=region,\n",
" version=\"1.0-1\",\n",
" py_version=\"py3\",\n",
" instance_type=\"ml.m5.xlarge\",\n",
")\n",
"\n",
"xgb_train = Estimator(\n",
" image_uri=image_uri,\n",
" instance_type=training_instance_type,\n",
" instance_count=1,\n",
" output_path=model_path,\n",
" base_job_name=f\"{base_job_prefix}/abalone-train\",\n",
" sagemaker_session=pipeline_session,\n",
" role=role,\n",
")\n",
"\n",
"xgb_train.set_hyperparameters(\n",
" objective=\"reg:linear\",\n",
" num_round=50,\n",
" max_depth=5,\n",
" eta=0.2,\n",
" gamma=4,\n",
" min_child_weight=6,\n",
" subsample=0.7,\n",
" silent=0,\n",
")\n",
"\n",
"train_args = xgb_train.fit(\n",
" inputs={\n",
" \"train\": TrainingInput(\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\"train\"].S3Output.S3Uri,\n",
" content_type=\"text/csv\",\n",
" ),\n",
" \"validation\": TrainingInput(\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"validation\"\n",
" ].S3Output.S3Uri,\n",
" content_type=\"text/csv\",\n",
" ),\n",
" },\n",
")\n",
"step_train = TrainingStep(\n",
" name=\"TrainAbaloneModel\",\n",
" step_args=train_args,\n",
" depends_on=[data_bias_check_step.name, data_quality_check_step.name],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the model\n",
"\n",
"The model is created so that a batch transform job can be used to get predictions from the model on a test dataset. These predictions are used when calculating model quality, model bias, and model explainability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model = Model(\n",
" image_uri=image_uri,\n",
" model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n",
" sagemaker_session=pipeline_session,\n",
" role=role,\n",
")\n",
"\n",
"step_create_model = ModelStep(\n",
" name=\"AbaloneCreateModel\",\n",
" step_args=model.create(instance_type=\"ml.m5.large\", accelerator_type=\"ml.eia1.medium\"),\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Transform Output\n",
"\n",
"The output of the transform step combines the prediction and the input label. The output format is
\n",
"`prediction, original label`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"transformer = Transformer(\n",
" model_name=step_create_model.properties.ModelName,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" instance_count=1,\n",
" accept=\"text/csv\",\n",
" assemble_with=\"Line\",\n",
" output_path=f\"s3://{default_bucket}/AbaloneTransform\",\n",
")\n",
"\n",
"step_transform = TransformStep(\n",
" name=\"AbaloneTransform\",\n",
" transformer=transformer,\n",
" inputs=TransformInput(\n",
" data=step_process.properties.ProcessingOutputConfig.Outputs[\"test\"].S3Output.S3Uri,\n",
" input_filter=\"$[1:]\",\n",
" join_source=\"Input\",\n",
" output_filter=\"$[0,-1]\",\n",
" content_type=\"text/csv\",\n",
" split_type=\"Line\",\n",
" ),\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check the Model Quality\n",
"\n",
"In this `QualityCheckStep` we calculate the baselines for statistics and constraints using the predictions that the model generates from the test dataset (output from the TransformStep). We define the problem type as 'Regression' in the `ModelQualityCheckConfig` along with specifying the columns which represent the input and output. Since the dataset has no headers, `_c0`, `_c1` are auto-generated header names that should be used in the `ModelQualityCheckConfig`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_quality_check_config = ModelQualityCheckConfig(\n",
" baseline_dataset=step_transform.properties.TransformOutput.S3OutputPath,\n",
" dataset_format=DatasetFormat.csv(header=False),\n",
" output_s3_uri=Join(\n",
" on=\"/\",\n",
" values=[\n",
" \"s3:/\",\n",
" default_bucket,\n",
" base_job_prefix,\n",
" ExecutionVariables.PIPELINE_EXECUTION_ID,\n",
" \"modelqualitycheckstep\",\n",
" ],\n",
" ),\n",
" problem_type=\"Regression\",\n",
" inference_attribute=\"_c0\", # use auto-populated headers since we don't have headers in the dataset\n",
" ground_truth_attribute=\"_c1\", # use auto-populated headers since we don't have headers in the dataset\n",
")\n",
"\n",
"model_quality_check_step = QualityCheckStep(\n",
" name=\"ModelQualityCheckStep\",\n",
" skip_check=skip_check_model_quality,\n",
" register_new_baseline=register_new_baseline_model_quality,\n",
" quality_check_config=model_quality_check_config,\n",
" check_job_config=check_job_config,\n",
" supplied_baseline_statistics=supplied_baseline_statistics_model_quality,\n",
" supplied_baseline_constraints=supplied_baseline_constraints_model_quality,\n",
" model_package_group_name=model_package_group_name,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check for Model Bias\n",
"\n",
"Similar to the Data Bias check step, a `BiasConfig` is defined and Clarify is used to calculate the model bias using the training dataset and the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_bias_analysis_cfg_output_path = (\n",
" f\"s3://{default_bucket}/{base_job_prefix}/modelbiascheckstep/analysis_cfg\"\n",
")\n",
"\n",
"model_bias_data_config = DataConfig(\n",
" s3_data_input_path=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"train\"\n",
" ].S3Output.S3Uri,\n",
" s3_output_path=Join(\n",
" on=\"/\",\n",
" values=[\n",
" \"s3:/\",\n",
" default_bucket,\n",
" base_job_prefix,\n",
" ExecutionVariables.PIPELINE_EXECUTION_ID,\n",
" \"modelbiascheckstep\",\n",
" ],\n",
" ),\n",
" s3_analysis_config_output_path=model_bias_analysis_cfg_output_path,\n",
" label=0,\n",
" dataset_type=\"text/csv\",\n",
")\n",
"\n",
"model_config = ModelConfig(\n",
" model_name=step_create_model.properties.ModelName,\n",
" instance_count=1,\n",
" instance_type=\"ml.m5.xlarge\",\n",
")\n",
"\n",
"# We are using this bias config to configure Clarify to detect bias based on the first feature in the featurized vector for Sex\n",
"model_bias_config = BiasConfig(\n",
" label_values_or_threshold=[15.0], facet_name=[8], facet_values_or_threshold=[[0.5]]\n",
")\n",
"\n",
"model_bias_check_config = ModelBiasCheckConfig(\n",
" data_config=model_bias_data_config,\n",
" data_bias_config=model_bias_config,\n",
" model_config=model_config,\n",
" model_predicted_label_config=ModelPredictedLabelConfig(),\n",
")\n",
"\n",
"model_bias_check_step = ClarifyCheckStep(\n",
" name=\"ModelBiasCheckStep\",\n",
" clarify_check_config=model_bias_check_config,\n",
" check_job_config=check_job_config,\n",
" skip_check=skip_check_model_bias,\n",
" register_new_baseline=register_new_baseline_model_bias,\n",
" supplied_baseline_constraints=supplied_baseline_constraints_model_bias,\n",
" model_package_group_name=model_package_group_name,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check Model Explainability\n",
"\n",
"SageMaker Clarify uses a model-agnostic feature attribution approach, which you can use to understand why a model made a prediction after training and to provide per-instance explanation during inference. The implementation includes a scalable and efficient implementation of SHAP, based on the concept of a Shapley value from the field of cooperative game theory that assigns each feature an importance value for a particular prediction.\n",
"\n",
"For Model Explainability, Clarify requires an explainability configuration to be provided. In this example, we use `SHAPConfig`. For more information of `explainability_config`, visit the [Clarify documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_explainability_analysis_cfg_output_path = \"s3://{}/{}/{}/{}\".format(\n",
" default_bucket, base_job_prefix, \"modelexplainabilitycheckstep\", \"analysis_cfg\"\n",
")\n",
"\n",
"model_explainability_data_config = DataConfig(\n",
" s3_data_input_path=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"train\"\n",
" ].S3Output.S3Uri,\n",
" s3_output_path=Join(\n",
" on=\"/\",\n",
" values=[\n",
" \"s3:/\",\n",
" default_bucket,\n",
" base_job_prefix,\n",
" ExecutionVariables.PIPELINE_EXECUTION_ID,\n",
" \"modelexplainabilitycheckstep\",\n",
" ],\n",
" ),\n",
" s3_analysis_config_output_path=model_explainability_analysis_cfg_output_path,\n",
" label=0,\n",
" dataset_type=\"text/csv\",\n",
")\n",
"shap_config = SHAPConfig(seed=123, num_samples=10)\n",
"model_explainability_check_config = ModelExplainabilityCheckConfig(\n",
" data_config=model_explainability_data_config,\n",
" model_config=model_config,\n",
" explainability_config=shap_config,\n",
")\n",
"model_explainability_check_step = ClarifyCheckStep(\n",
" name=\"ModelExplainabilityCheckStep\",\n",
" clarify_check_config=model_explainability_check_config,\n",
" check_job_config=check_job_config,\n",
" skip_check=skip_check_model_explainability,\n",
" register_new_baseline=register_new_baseline_model_explainability,\n",
" supplied_baseline_constraints=supplied_baseline_constraints_model_explainability,\n",
" model_package_group_name=model_package_group_name,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate the performance of the model\n",
"\n",
"Using a processing job, evaluate the performance of the model. The performance is used in the Condition Step to determine if the model should be registered or not."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%%writefile code/evaluate.py\n",
"\n",
"\"\"\"Evaluation script for measuring mean squared error.\"\"\"\n",
"import json\n",
"import logging\n",
"import pathlib\n",
"import pickle\n",
"import tarfile\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import xgboost\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.addHandler(logging.StreamHandler())\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" logger.debug(\"Starting evaluation.\")\n",
" model_path = \"/opt/ml/processing/model/model.tar.gz\"\n",
" with tarfile.open(model_path) as tar:\n",
" tar.extractall(path=\".\")\n",
"\n",
" logger.debug(\"Loading xgboost model.\")\n",
" model = pickle.load(open(\"xgboost-model\", \"rb\"))\n",
"\n",
" logger.debug(\"Reading test data.\")\n",
" test_path = \"/opt/ml/processing/test/test.csv\"\n",
" df = pd.read_csv(test_path, header=None)\n",
"\n",
" logger.debug(\"Reading test data.\")\n",
" y_test = df.iloc[:, 0].to_numpy()\n",
" df.drop(df.columns[0], axis=1, inplace=True)\n",
" X_test = xgboost.DMatrix(df.values)\n",
"\n",
" logger.info(\"Performing predictions against test data.\")\n",
" predictions = model.predict(X_test)\n",
"\n",
" logger.debug(\"Calculating mean squared error.\")\n",
" mse = mean_squared_error(y_test, predictions)\n",
" std = np.std(y_test - predictions)\n",
" report_dict = {\n",
" \"regression_metrics\": {\n",
" \"mse\": {\"value\": mse, \"standard_deviation\": std},\n",
" },\n",
" }\n",
"\n",
" output_dir = \"/opt/ml/processing/evaluation\"\n",
" pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)\n",
"\n",
" logger.info(\"Writing out evaluation report with mse: %f\", mse)\n",
" evaluation_path = f\"{output_dir}/evaluation.json\"\n",
" with open(evaluation_path, \"w\") as f:\n",
" f.write(json.dumps(report_dict))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"script_eval = ScriptProcessor(\n",
" image_uri=image_uri,\n",
" command=[\"python3\"],\n",
" instance_type=\"ml.m5.xlarge\",\n",
" instance_count=1,\n",
" base_job_name=f\"{base_job_prefix}/script-abalone-eval\",\n",
" sagemaker_session=pipeline_session,\n",
" role=role,\n",
")\n",
"evaluation_report = PropertyFile(\n",
" name=\"AbaloneEvaluationReport\",\n",
" output_name=\"evaluation\",\n",
" path=\"evaluation.json\",\n",
")\n",
"\n",
"eval_args = script_eval.run(\n",
" inputs=[\n",
" ProcessingInput(\n",
" source=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n",
" destination=\"/opt/ml/processing/model\",\n",
" ),\n",
" ProcessingInput(\n",
" source=step_process.properties.ProcessingOutputConfig.Outputs[\"test\"].S3Output.S3Uri,\n",
" destination=\"/opt/ml/processing/test\",\n",
" ),\n",
" ],\n",
" outputs=[\n",
" ProcessingOutput(output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\"),\n",
" ],\n",
" code=\"code/evaluate.py\",\n",
")\n",
"step_eval = ProcessingStep(\n",
" name=\"EvaluateAbaloneModel\",\n",
" step_args=eval_args,\n",
" property_files=[evaluation_report],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the metrics to be registered with the model in the Model Registry"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_metrics = ModelMetrics(\n",
" model_data_statistics=MetricsSource(\n",
" s3_uri=data_quality_check_step.properties.CalculatedBaselineStatistics,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_data_constraints=MetricsSource(\n",
" s3_uri=data_quality_check_step.properties.CalculatedBaselineConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" bias_pre_training=MetricsSource(\n",
" s3_uri=data_bias_check_step.properties.CalculatedBaselineConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_statistics=MetricsSource(\n",
" s3_uri=model_quality_check_step.properties.CalculatedBaselineStatistics,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_constraints=MetricsSource(\n",
" s3_uri=model_quality_check_step.properties.CalculatedBaselineConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" bias_post_training=MetricsSource(\n",
" s3_uri=model_bias_check_step.properties.CalculatedBaselineConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" explainability=MetricsSource(\n",
" s3_uri=model_explainability_check_step.properties.CalculatedBaselineConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
")\n",
"\n",
"drift_check_baselines = DriftCheckBaselines(\n",
" model_data_statistics=MetricsSource(\n",
" s3_uri=data_quality_check_step.properties.BaselineUsedForDriftCheckStatistics,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_data_constraints=MetricsSource(\n",
" s3_uri=data_quality_check_step.properties.BaselineUsedForDriftCheckConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" bias_pre_training_constraints=MetricsSource(\n",
" s3_uri=data_bias_check_step.properties.BaselineUsedForDriftCheckConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" bias_config_file=FileSource(\n",
" s3_uri=model_bias_check_config.monitoring_analysis_config_uri,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_statistics=MetricsSource(\n",
" s3_uri=model_quality_check_step.properties.BaselineUsedForDriftCheckStatistics,\n",
" content_type=\"application/json\",\n",
" ),\n",
" model_constraints=MetricsSource(\n",
" s3_uri=model_quality_check_step.properties.BaselineUsedForDriftCheckConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" bias_post_training_constraints=MetricsSource(\n",
" s3_uri=model_bias_check_step.properties.BaselineUsedForDriftCheckConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" explainability_constraints=MetricsSource(\n",
" s3_uri=model_explainability_check_step.properties.BaselineUsedForDriftCheckConstraints,\n",
" content_type=\"application/json\",\n",
" ),\n",
" explainability_config_file=FileSource(\n",
" s3_uri=model_explainability_check_config.monitoring_analysis_config_uri,\n",
" content_type=\"application/json\",\n",
" ),\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Register the model\n",
"\n",
"The two parameters in `RegisterModel` that hold the metrics calculated by the `ClarifyCheckStep` and `QualityCheckStep` are `model_metrics` and `drift_check_baselines`.\n",
"\n",
"`drift_check_baselines` - these are the baseline files that will be used for drift checks in `QualityCheckStep` or `ClarifyCheckStep` and model monitoring jobs that are set up on endpoints hosting this model.\n",
"\n",
"`model_metrics` - these should be the latest baselines calculated in the pipeline run. This can be set using the step property `CalculatedBaseline`\n",
"\n",
"The intention behind these parameters is to give users a way to configure the baselines associated with a model so they can be used in drift checks or model monitoring jobs. Each time a pipeline is executed, users can choose to update the `drift_check_baselines` with newly calculated baselines. The `model_metrics` can be used to register the newly calculated baselines or any other metrics associated with the model.\n",
"\n",
"Every time a baseline is calculated, it is not necessary that the baselines used for drift checks are updated to the newly calculated baselines. In some cases, users may retain an older version of the baseline file to be used for drift checks and not register new baselines that are calculated in the Pipeline run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"register_args = model.register(\n",
" content_types=[\"text/csv\"],\n",
" response_types=[\"text/csv\"],\n",
" inference_instances=[\"ml.t2.medium\", \"ml.m5.large\"],\n",
" transform_instances=[\"ml.m5.large\"],\n",
" model_package_group_name=model_package_group_name,\n",
" approval_status=model_approval_status,\n",
" model_metrics=model_metrics,\n",
" drift_check_baselines=drift_check_baselines,\n",
")\n",
"\n",
"step_register = ModelStep(name=\"RegisterAbaloneModel\", step_args=register_args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# condition step for evaluating model quality and branching execution\n",
"cond_lte = ConditionLessThanOrEqualTo(\n",
" left=JsonGet(\n",
" step_name=step_eval.name,\n",
" property_file=evaluation_report,\n",
" json_path=\"regression_metrics.mse.value\",\n",
" ),\n",
" right=6.0,\n",
")\n",
"step_cond = ConditionStep(\n",
" name=\"CheckMSEAbaloneEvaluation\",\n",
" conditions=[cond_lte],\n",
" if_steps=[step_register],\n",
" else_steps=[],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the Pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# pipeline instance\n",
"pipeline = Pipeline(\n",
" name=pipeline_name,\n",
" parameters=[\n",
" processing_instance_count,\n",
" training_instance_type,\n",
" model_approval_status,\n",
" input_data,\n",
" skip_check_data_quality,\n",
" register_new_baseline_data_quality,\n",
" supplied_baseline_statistics_data_quality,\n",
" supplied_baseline_constraints_data_quality,\n",
" skip_check_data_bias,\n",
" register_new_baseline_data_bias,\n",
" supplied_baseline_constraints_data_bias,\n",
" skip_check_model_quality,\n",
" register_new_baseline_model_quality,\n",
" supplied_baseline_statistics_model_quality,\n",
" supplied_baseline_constraints_model_quality,\n",
" skip_check_model_bias,\n",
" register_new_baseline_model_bias,\n",
" supplied_baseline_constraints_model_bias,\n",
" skip_check_model_explainability,\n",
" register_new_baseline_model_explainability,\n",
" supplied_baseline_constraints_model_explainability,\n",
" ],\n",
" steps=[\n",
" step_process,\n",
" data_quality_check_step,\n",
" data_bias_check_step,\n",
" step_train,\n",
" step_create_model,\n",
" step_transform,\n",
" model_quality_check_step,\n",
" model_bias_check_step,\n",
" model_explainability_check_step,\n",
" step_eval,\n",
" step_cond,\n",
" ],\n",
" sagemaker_session=pipeline_session,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get Pipeline definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import json\n",
"\n",
"definition = json.loads(pipeline.definition())\n",
"definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"pipeline.upsert(role_arn=role)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### First time executing\n",
"\n",
"The first time the pipeline is run the parameters need to be overridden so that the checks are skipped and newly calculated baselines are registered"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"execution = pipeline.start(\n",
" parameters=dict(\n",
" SkipDataQualityCheck=True,\n",
" RegisterNewDataQualityBaseline=True,\n",
" SkipDataBiasCheck=True,\n",
" RegisterNewDataBiasBaseline=True,\n",
" SkipModelQualityCheck=True,\n",
" RegisterNewModelQualityBaseline=True,\n",
" SkipModelBiasCheck=True,\n",
" RegisterNewModelBiasBaseline=True,\n",
" SkipModelExplainabilityCheck=True,\n",
" RegisterNewModelExplainabilityBaseline=True,\n",
" )\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wait for the pipeline execution to complete"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"execution.wait()"
]
},
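{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Subsequent executions\n",
"\n",
"On later runs, the `Skip*Check` and `RegisterNew*Baseline` parameters can be left at their default value of False, so each check step calculates a new baseline and runs a drift check against the `drift_check_baselines` registered with the latest approved model version. The cell below is only a sketch and is left commented out because it would start another full pipeline execution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Sketch of a later run: all Skip*/RegisterNew* parameters default to False,\n",
"# so the checks execute against the baselines registered in the Model Registry.\n",
"\n",
"# execution = pipeline.start()\n",
"# execution.wait()"
]
},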
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cleaning up resources\n",
"\n",
"Users are responsible for cleaning up resources created when running this notebook. Specify the ModelName, ModelPackageName, and ModelPackageGroupName that need to be deleted. The model names are generated by the CreateModel step of the Pipeline and the property values are available only in the Pipeline context. To delete the models created by this pipeline, navigate to the Model Registry and Console to find the models to delete.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Create a SageMaker client\n",
"# sm_client = boto3.client(\"sagemaker\")\n",
"\n",
"# # Delete SageMaker Models\n",
"# sm_client.delete_model(ModelName=\"...\")\n",
"\n",
"# # Delete Model Packages\n",
"# sm_client.delete_model_package(ModelPackageName=\"...\")\n",
"\n",
"# # Delete the Model Package Group\n",
"# sm_client.delete_model_package_group(ModelPackageGroupName=\"model-monitor-clarify-group\")\n",
"\n",
"# # Delete the Pipeline\n",
"# sm_client.delete_pipeline(PipelineName=\"model-monitor-clarify-pipeline\")"
]
},
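{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to browsing the console, the model packages registered by this pipeline can be listed with the boto3 client created earlier; a minimal sketch is shown below, left commented out like the rest of this section."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Sketch: list the model packages in this notebook's model package group so their\n",
"# ARNs can be passed to delete_model_package above.\n",
"\n",
"# response = sm_client.list_model_packages(ModelPackageGroupName=model_package_group_name)\n",
"# for package in response[\"ModelPackageSummaryList\"]:\n",
"#     print(package[\"ModelPackageArn\"])"
]
},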
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notebook CI Test Results\n",
"\n",
"This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
}
],
"metadata": {
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (Data Science 3.0)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}