{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "775f08d1-6f07-4777-a697-1f603b68a93d",
   "metadata": {},
   "source": [
    "# Build and Register an MXNet Image Classification Model via SageMaker Pipelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a206ae5e-e21d-46ba-962c-d737edf98b26",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic rather than !pip so the packages are installed into the\n",
    "# environment of the running kernel, not whichever pip is first on the shell PATH.\n",
    "%pip install --upgrade pip\n",
    "%pip install --upgrade sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5ce22e9-81a3-4ec9-b109-c8aac18167bc",
   "metadata": {},
   "source": [
    "## 1. Create SageMaker session and client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a4f5fa-2d47-407c-99d2-f9c45492db37",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "import boto3\n",
    "import sagemaker\n",
    "\n",
    "# Resolve the region from the default boto3 configuration and pin a boto3 session to it\n",
    "region = boto3.Session().region_name\n",
    "boto_session = boto3.Session(region_name=region)\n",
    "\n",
    "# SageMaker API client for that region\n",
    "sm_client = boto3.Session().client(service_name=\"sagemaker\", region_name=region)\n",
    "\n",
    "# SageMaker SDK session, its default S3 bucket, and the execution role\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2993ef2-a200-4499-b7d6-8af87de0a79f",
   "metadata": {},
   "source": [
    "## 2. Define SageMaker Pipeline parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c421c687-5571-4026-a6ce-761272e34068",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.workflow.pipeline_context import PipelineSession\n",
    "from sagemaker.workflow.parameters import ParameterFloat, ParameterInteger, ParameterString\n",
    "\n",
    "\n",
    "# Use a PipelineSession so that the steps defined below are not run before the pipeline itself is ready\n",
    "sm_pipeline_session = PipelineSession(\n",
    "    boto_session=boto_session,\n",
    "    sagemaker_client=sm_client,\n",
    "    default_bucket=bucket,\n",
    ")\n",
    "\n",
    "# Fixed names used throughout the notebook\n",
    "model_package_group_name = \"MXNet-Image-Classification\"  # Model name in model registry\n",
    "prefix = \"mxnet-image-classification-pipeline\"\n",
    "pipeline_name = \"MXNetImageClassificationPipeline\"\n",
    "\n",
    "###### TODO ######\n",
    "input_img_data_s3_uri = \"s3://TODO\" # e.g. \"s3://my-image-bucket\"\n",
    "\n",
    "# S3 location of the input images\n",
    "input_data = ParameterString(name=\"InputData\", default_value=input_img_data_s3_uri)\n",
    "\n",
    "# Compute resources for the processing and training jobs\n",
    "processing_instance_count = ParameterInteger(\n",
    "    name=\"ProcessingInstanceCount\",\n",
    "    default_value=1,\n",
    ")\n",
    "processing_instance_type = ParameterString(\n",
    "    name=\"ProcessingInstanceType\",\n",
    "    default_value=\"ml.t3.medium\",\n",
    ")\n",
    "training_instance_type = ParameterString(\n",
    "    name=\"TrainingInstanceType\",\n",
    "    default_value=\"ml.p3.2xlarge\",\n",
    ")\n",
    "\n",
    "# Train / validation / test split fractions\n",
    "train_split_percentage = ParameterFloat(name=\"TrainSplitPercentage\", default_value=0.75)\n",
    "validation_split_percentage = ParameterFloat(name=\"ValidationSplitPercentage\", default_value=0.10)\n",
    "test_split_percentage = ParameterFloat(name=\"TestSplitPercentage\", default_value=0.15)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b58493d2-8990-4e74-bd02-281dbde3895d",
   "metadata": {},
   "source": [
    "## 3. Define image preprocessing step\n",
    "\n",
    "In this ML workflow step, you will be converting the raw .jpg image files in the S3 input bucket to the Apache MXNet RecordIO format, which is the recommended input format for the Amazon SageMaker image classification algorithm. Upon completion, the newly generated .rec files will be split into train, validation, and test sets and uploaded back to the original S3 input bucket under a newly created `recordIO` folder. This step relies on the `preprocess.py` script found in the `scripts` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2969fa1e-b993-475e-a9e6-9634a15ff5ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.mxnet import MXNetProcessor\n",
    "from sagemaker.processing import ProcessingInput, ProcessingOutput\n",
    "from sagemaker.workflow.functions import Join\n",
    "from sagemaker.workflow.steps import ProcessingStep\n",
    "\n",
    "\n",
    "# Pass the pipeline parameters themselves (not their .default_value) so that values\n",
    "# supplied to pipeline.start(parameters=...) are honoured at execution time.\n",
    "# NOTE(review): the SDK is expected to pick the container image from the parameter's\n",
    "# default value when this cell runs - confirm that overriding the instance type with a\n",
    "# different CPU/GPU family is compatible with that image.\n",
    "mxnet_processor_preprocess = MXNetProcessor(\n",
    "    framework_version=\"1.8.0\",\n",
    "    py_version=\"py37\",\n",
    "    instance_type=processing_instance_type,\n",
    "    instance_count=processing_instance_count,\n",
    "    base_job_name=f\"{prefix}/preprocess-image-data\",\n",
    "    sagemaker_session=sm_pipeline_session,\n",
    "    role=role\n",
    ")\n",
    "\n",
    "# Raw images are read from the InputData location\n",
    "processing_inputs = [\n",
    "    ProcessingInput(\n",
    "        input_name=\"input_img_data\",\n",
    "        source=input_data,\n",
    "        destination=\"/opt/ml/processing/input/data/\"\n",
    "    )\n",
    "]\n",
    "\n",
    "# Each split is written under <InputData>/recordIO/<split>. Join builds the URI at\n",
    "# execution time, so the outputs follow the InputData parameter instead of its default.\n",
    "processing_outputs = [\n",
    "    ProcessingOutput(\n",
    "        output_name=split_name,\n",
    "        source=f\"/opt/ml/processing/{split_name}\",\n",
    "        destination=Join(on=\"/\", values=[input_data, \"recordIO\", split_name])\n",
    "    )\n",
    "    for split_name in (\"train\", \"validation\", \"test\")\n",
    "]\n",
    "\n",
    "# Container arguments must be strings, so the float parameters are converted with\n",
    "# to_string(), which is resolved when the pipeline executes.\n",
    "step_args = mxnet_processor_preprocess.run(\n",
    "    code=\"./scripts/preprocess.py\",\n",
    "    inputs=processing_inputs,\n",
    "    outputs=processing_outputs,\n",
    "    arguments=[\n",
    "        \"--input-s3-bucket\",\n",
    "        input_data,\n",
    "        \"--train-split-percentage\",\n",
    "        train_split_percentage.to_string(),\n",
    "        \"--validation-split-percentage\",\n",
    "        validation_split_percentage.to_string(),\n",
    "        \"--test-split-percentage\",\n",
    "        test_split_percentage.to_string()\n",
    "    ]\n",
    ")\n",
    "\n",
    "step_preprocess = ProcessingStep(\n",
    "    name=\"Preprocess-Image-Data\",\n",
    "    step_args=step_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f160617-a110-4329-9fb4-0d251ba57676",
   "metadata": {},
   "source": [
    "## 4. Define model training step\n",
    "\n",
    "In this ML workflow step, you will train an MXNet image classification model using the train and validation .rec files that were created in the previous step. For more information regarding the specific image classification algorithm, refer to [Image Classification - MXNet](https://docs.aws.amazon.com/sagemaker/latest/dg/image-classification.html) from the Amazon SageMaker documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdf744eb-ab88-455c-9f60-3f0bd33c58eb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.estimator import Estimator\n",
    "from sagemaker.inputs import TrainingInput\n",
    "from sagemaker.workflow.steps import TrainingStep\n",
    "\n",
    "\n",
    "# Container image of the image classification algorithm for the current region\n",
    "image_uri = sagemaker.image_uris.retrieve(framework=\"image-classification\", region=region)\n",
    "\n",
    "# S3 location passed to the estimator as its output path\n",
    "model_output_path = f\"s3://{bucket}/{prefix}/model-output\"\n",
    "\n",
    "mxnet_train = Estimator(\n",
    "    image_uri=image_uri,\n",
    "    role=role,\n",
    "    sagemaker_session=sm_pipeline_session,\n",
    "    instance_type=training_instance_type,\n",
    "    instance_count=1,\n",
    "    volume_size=50,\n",
    "    max_run=360000,\n",
    "    output_path=model_output_path,\n",
    ")\n",
    "\n",
    "# Feel free to edit these model hyperparameters based on domain expertise\n",
    "# For more information regarding image classification hyperparameters, refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html\n",
    "hyperparameters = dict(\n",
    "    use_pretrained_model=1,\n",
    "    image_shape='3,224,224',\n",
    "    num_classes=2,\n",
    "    num_training_samples=750, ### TODO ### (total number of images uploaded * train_split_percentage) (1000 * 0.75)\n",
    "    learning_rate=0.1,\n",
    "    mini_batch_size=25,\n",
    "    epochs=30,\n",
    "    early_stopping=True,\n",
    ")\n",
    "mxnet_train.set_hyperparameters(**hyperparameters)\n",
    "\n",
    "# Build one RecordIO channel per split from the outputs of the preprocessing step\n",
    "preprocess_outputs = step_preprocess.properties.ProcessingOutputConfig.Outputs\n",
    "training_channels = {\n",
    "    channel: TrainingInput(\n",
    "        s3_data=preprocess_outputs[channel].S3Output.S3Uri,\n",
    "        distribution='FullyReplicated',\n",
    "        content_type='application/x-recordio',\n",
    "        s3_data_type='S3Prefix',\n",
    "    )\n",
    "    for channel in (\"train\", \"validation\")\n",
    "}\n",
    "train_data = training_channels[\"train\"]\n",
    "validation_data = training_channels[\"validation\"]\n",
    "\n",
    "step_args = mxnet_train.fit(inputs=training_channels, logs=True)\n",
    "\n",
    "step_train = TrainingStep(name=\"Train-Image-Classification-Model\", step_args=step_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4efcad1-cdea-4de6-b8d5-b5421c561339",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 5. Define model evaluation step\n",
    "\n",
    "In this ML workflow step, you will be evaluating the trained model (from the previous step) on the test .rec file. Specifically, you will be measuring the accuracy and F1 score on the test set. This step relies on the `evaluate.py` script found in the `scripts` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2c03cc-5848-476f-ad63-8df6e1f5645c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.workflow.properties import PropertyFile\n",
    "\n",
    "# Pass the pipeline parameters themselves (not their .default_value) so that values\n",
    "# supplied to pipeline.start(parameters=...) are honoured at execution time.\n",
    "# NOTE(review): the SDK is expected to pick the container image from the parameter's\n",
    "# default value when this cell runs - confirm that overriding the instance type with a\n",
    "# different CPU/GPU family is compatible with that image.\n",
    "mxnet_processor_eval = MXNetProcessor(\n",
    "    framework_version=\"1.8.0\",\n",
    "    py_version=\"py37\",\n",
    "    instance_type=processing_instance_type,\n",
    "    instance_count=processing_instance_count,\n",
    "    base_job_name=f\"{prefix}/model-evaluation\",\n",
    "    sagemaker_session=sm_pipeline_session,\n",
    "    role=role\n",
    ")\n",
    "\n",
    "# Inputs: the model artifact from the training step and the test split from preprocessing\n",
    "processing_inputs = [\n",
    "    ProcessingInput(\n",
    "        source=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n",
    "        destination=\"/opt/ml/processing/model\"\n",
    "    ),\n",
    "    ProcessingInput(\n",
    "        source=step_preprocess.properties.ProcessingOutputConfig.Outputs[\"test\"].S3Output.S3Uri,\n",
    "        destination=\"/opt/ml/processing/test\"\n",
    "    )\n",
    "]\n",
    "\n",
    "# Output: whatever the script writes to /opt/ml/processing/evaluation\n",
    "processing_outputs = [\n",
    "    ProcessingOutput(\n",
    "        output_name=\"evaluation\",\n",
    "        source=\"/opt/ml/processing/evaluation\",\n",
    "        destination=f\"s3://{bucket}/{prefix}/model-evaluation\"\n",
    "    )\n",
    "]\n",
    "\n",
    "step_args = mxnet_processor_eval.run(\n",
    "    code=\"./scripts/evaluate.py\",\n",
    "    inputs=processing_inputs,\n",
    "    outputs=processing_outputs\n",
    ")\n",
    "\n",
    "# Exposes evaluation.json from the \"evaluation\" output so later steps can read values from it\n",
    "evaluation_report = PropertyFile(\n",
    "    name=\"Image-Classification-Model-Evaluation-Report\",\n",
    "    output_name=\"evaluation\",\n",
    "    path=\"evaluation.json\"\n",
    ")\n",
    "\n",
    "step_eval = ProcessingStep(\n",
    "    name=\"Evaluate-Image-Classification-Model\",\n",
    "    step_args=step_args,\n",
    "    property_files=[evaluation_report]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1a71b2-b5f6-4d33-a734-75969d767bf2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 6. Define model registering step\n",
    "\n",
    "In this ML workflow step, you will conditionally register the trained model into SageMaker model registry only if the trained model accuracy (on the test set) is greater than or equal to 70%.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b22867-83d7-4143-8ce1-994ff1ce6e3b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.model import Model\n",
    "from sagemaker.model_metrics import MetricsSource, ModelMetrics \n",
    "from sagemaker.workflow.model_step import ModelStep\n",
    "from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo\n",
    "from sagemaker.workflow.condition_step import ConditionStep\n",
    "from sagemaker.workflow.functions import JsonGet\n",
    "\n",
    "\n",
    "# Minimum test-set accuracy the model must reach to be registered\n",
    "accuracy_threshold = 0.70\n",
    "\n",
    "# Model backed by the artifact produced by the training step\n",
    "model = Model(\n",
    "    image_uri=image_uri,\n",
    "    model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n",
    "    sagemaker_session=sm_pipeline_session,\n",
    "    role=role,\n",
    ")\n",
    "\n",
    "# Attach the evaluation report written by the evaluation step as model statistics\n",
    "evaluation_output_s3_uri = step_eval.arguments[\"ProcessingOutputConfig\"][\"Outputs\"][0][\"S3Output\"][\"S3Uri\"]\n",
    "model_metrics = ModelMetrics(\n",
    "    model_statistics=MetricsSource(\n",
    "        s3_uri=\"{}/evaluation.json\".format(evaluation_output_s3_uri),\n",
    "        content_type=\"application/json\",\n",
    "    )\n",
    ")\n",
    "\n",
    "step_args = model.register(\n",
    "    model_package_group_name=model_package_group_name,\n",
    "    model_metrics=model_metrics,\n",
    "    content_types=[\"image/png\"],\n",
    "    response_types=[\"application/json\"],\n",
    "    inference_instances=[\"ml.m5.xlarge\"],\n",
    ")\n",
    "\n",
    "step_register = ModelStep(name=\"Register-Image-Classification-Model\", step_args=step_args)\n",
    "\n",
    "# Condition: accuracy read from the evaluation report is at least 0.70\n",
    "reported_accuracy = JsonGet(\n",
    "    step_name=step_eval.name,\n",
    "    property_file=evaluation_report,\n",
    "    json_path=\"classification_metrics.accuracy.value\",\n",
    ")\n",
    "cond_gte = ConditionGreaterThanOrEqualTo(left=reported_accuracy, right=accuracy_threshold)\n",
    "\n",
    "# Wraps 'step_register' so that it only runs when the condition holds; nothing runs otherwise\n",
    "step_cond = ConditionStep(\n",
    "    name=\"Check-Accuracy-Image-Classification-Model\",\n",
    "    conditions=[cond_gte],\n",
    "    if_steps=[step_register],\n",
    "    else_steps=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9babb47-bcde-4bb3-8aa4-a0cd44a6242a",
   "metadata": {},
   "source": [
    "## 7. Create SageMaker Pipeline\n",
    "\n",
    "In this step, you define a SageMaker pipeline encompassing all the above ML workflow steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "858ceb0f-1f1d-4789-8b9e-a2f11d8ec4ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.workflow.pipeline import Pipeline\n",
    "import json\n",
    "\n",
    "\n",
    "# Parameters registered on the pipeline\n",
    "pipeline_parameters = [\n",
    "    input_data,\n",
    "    processing_instance_count,\n",
    "    processing_instance_type,\n",
    "    training_instance_type,\n",
    "    train_split_percentage,\n",
    "    validation_split_percentage,\n",
    "    test_split_percentage,\n",
    "]\n",
    "\n",
    "# Workflow steps; the registration step is nested inside the condition step\n",
    "pipeline_steps = [step_preprocess, step_train, step_eval, step_cond]\n",
    "\n",
    "pipeline = Pipeline(\n",
    "    name=pipeline_name,\n",
    "    parameters=pipeline_parameters,\n",
    "    steps=pipeline_steps,\n",
    "    sagemaker_session=sm_pipeline_session,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4110e48e-84c6-4118-95d7-878338d39edf",
   "metadata": {},
   "source": [
    "## 8. Start SageMaker Pipeline execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff9bb60-e852-4747-914c-1f8504496d6f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Send the pipeline definition to the SageMaker Pipelines service:\n",
    "# the pipeline is created if it doesn't exist, or updated if it does\n",
    "upsert_response = pipeline.upsert(role_arn=role)\n",
    "\n",
    "# Kick off an execution of the pipeline\n",
    "execution = pipeline.start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "081f7d50-f1dd-4752-b52b-44abbb4fda82",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}