{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6b39bb5-1408-45ab-af98-fa786ce6d55f",
   "metadata": {},
   "source": [
    "# Step 2: Add SageMaker processing and training jobs\n",
    "In this step you move data processing and model training into [SageMaker Docker containers](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers.html) and use [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/index.html) to interact with SageMaker.\n",
    "\n",
    "![](img/six-steps-2.png)\n",
    "\n",
    "SageMaker makes use of Docker containers to enable developers to process data, train and deploy models. Containers allow developers and data scientists to package software into standardized units that run consistently on any platform that supports Docker. Containers ensure that code, runtime, system tools, system libraries, and settings are all in the same place, isolating them from the execution environment. It guarantees a consistent runtime experience regardless of where a container is being run.\n",
    "\n",
    "SageMaker also provides pre-build containers with popular data processing frameworks and ML algorithms. All SageMaker built-in algorithms are delivered as Docker containers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e4e390-2278-4a1c-8886-13d67507c818",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import boto3\n",
    "import botocore\n",
    "import numpy as np  \n",
    "import pandas as pd  \n",
    "import sagemaker\n",
    "from time import gmtime, strftime, sleep\n",
    "from sagemaker.sklearn.processing import SKLearnProcessor\n",
    "from sagemaker.processing import ProcessingInput, ProcessingOutput\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sagemaker.experiments.run import Run, load_run\n",
    "\n",
    "sagemaker.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c357f9a-0df7-4358-a231-62aaaeec605f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%store -r \n",
    "\n",
    "%store\n",
    "\n",
    "try:\n",
    "    initialized\n",
    "except NameError:\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] YOU HAVE TO RUN 00-start-here notebook   \")\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5214e551-4800-47c2-9eae-337f4fc5ed7f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "session = sagemaker.Session()\n",
    "sm = session.sagemaker_client"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5beedac-1847-47eb-b294-5eac2ecaa7ef",
   "metadata": {},
   "source": [
    "## Load or create an experiment\n",
    "Load the existing experiment we created in the previous notebook. You're going to track new runs in the same experiment.\n",
    "You can also create a new experiment to track runs in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e798262-7ee3-45e8-86b0-da6ff1cf9a61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Uncomment code block (Cmd + /) if you would like to create a new experiment\n",
    "# experiment_name = f\"from-idea-to-prod-experiment-{strftime('%d-%H-%M-%S', gmtime())}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b39c88c-6aed-47b4-8048-a27f35444e10",
   "metadata": {},
   "source": [
    "## Process data\n",
    "Use [SageMaker Processing](https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html) by simply providing a Python data preprocessing script and choosing a [SageMaker SDK processor](https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_processing.html) class.\n",
    "You must upload the input data to S3 and specify an S3 location for output data. SageMaker Processing automatically loads the input data from S3 and uploads transformed data back to S3 when the job is complete. The processing container image can either be an Amazon SageMaker built-in image or a custom image that you provide. The underlying infrastructure for a Processing job is fully managed by Amazon SageMaker. Cluster resources are provisioned for the duration of your job, and cleaned up when a job completes.\n",
    "\n",
    "![](img/sagemaker-processing.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b4f1fe-ef68-4151-8d7d-28269e39b68a",
   "metadata": {},
   "source": [
    "Your input data must be stored in an Amazon S3 bucket. Alternatively, you can use [Amazon Athena](https://sagemaker.readthedocs.io/en/stable/api/utility/inputs.html#sagemaker.dataset_definition.inputs.AthenaDatasetDefinition) or [Amazon Redshift](https://sagemaker.readthedocs.io/en/stable/api/utility/inputs.html#sagemaker.dataset_definition.inputs.AthenaDatasetDefinition) as input sources.\n",
    "\n",
    "Upload the input dataset to an Amazon S3 bucket:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ba50c3-8428-4756-9a2c-8b6f8d19a52b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_s3_url = session.upload_data(\n",
    "    path=\"data/bank-additional/bank-additional-full.csv\",\n",
    "    bucket=bucket_name,\n",
    "    key_prefix=f\"{bucket_prefix}/input\"\n",
    ")\n",
    "\n",
    "%store input_s3_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "118099a1-e789-435e-a309-4dd9043f219d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls {bucket_name}/{bucket_prefix} --recursive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a547ea17-c14a-41b4-a0d8-26164cf24615",
   "metadata": {},
   "source": [
    "Create a Python script by moving the data processing code from the step 1 notebook to a .py file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25f6f89d-c9fa-4432-bf3e-5cc21aeb2e6f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%writefile preprocessing.py\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "def _parse_args():\n",
    "    \n",
    "    parser = argparse.ArgumentParser()\n",
    "    # Data, model, and output directories\n",
    "    # model_dir is always passed in from SageMaker. By default this is a S3 path under the default bucket.\n",
    "    parser.add_argument('--filepath', type=str, default='/opt/ml/processing/input/')\n",
    "    parser.add_argument('--filename', type=str, default='bank-additional-full.csv')\n",
    "    parser.add_argument('--outputpath', type=str, default='/opt/ml/processing/output/')\n",
    "    \n",
    "    return parser.parse_known_args()\n",
    "\n",
    "\n",
    "if __name__==\"__main__\":\n",
    "    # Process arguments\n",
    "    args, _ = _parse_args()\n",
    "    \n",
    "    target_col = \"y\"\n",
    "    \n",
    "    # Load data\n",
    "    df_data = pd.read_csv(os.path.join(args.filepath, args.filename), sep=\";\")\n",
    "\n",
    "    # Indicator variable to capture when pdays takes a value of 999\n",
    "    df_data[\"no_previous_contact\"] = np.where(df_data[\"pdays\"] == 999, 1, 0)\n",
    "\n",
    "    # Indicator for individuals not actively employed\n",
    "    df_data[\"not_working\"] = np.where(\n",
    "        np.in1d(df_data[\"job\"], [\"student\", \"retired\", \"unemployed\"]), 1, 0\n",
    "    )\n",
    "\n",
    "    # remove unnecessary data\n",
    "    df_model_data = df_data.drop(\n",
    "        [\"duration\", \"emp.var.rate\", \"cons.price.idx\", \"cons.conf.idx\", \"euribor3m\", \"nr.employed\"],\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    df_model_data = pd.get_dummies(df_model_data)  # Convert categorical variables to sets of indicators\n",
    "\n",
    "    # Replace \"y_no\" and \"y_yes\" with a single label column, and bring it to the front:\n",
    "    df_model_data = pd.concat(\n",
    "        [\n",
    "            df_model_data[\"y_yes\"].rename(target_col),\n",
    "            df_model_data.drop([\"y_no\", \"y_yes\"], axis=1),\n",
    "        ],\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    # Shuffle and splitting dataset\n",
    "    train_data, validation_data, test_data = np.split(\n",
    "        df_model_data.sample(frac=1, random_state=1729),\n",
    "        [int(0.7 * len(df_model_data)), int(0.9 * len(df_model_data))],\n",
    "    )\n",
    "\n",
    "    print(f\"Data split > train:{train_data.shape} | validation:{validation_data.shape} | test:{test_data.shape}\")\n",
    "    \n",
    "    # Save datasets locally\n",
    "    train_data.to_csv(os.path.join(args.outputpath, 'train/train.csv'), index=False, header=False)\n",
    "    validation_data.to_csv(os.path.join(args.outputpath, 'validation/validation.csv'), index=False, header=False)\n",
    "    test_data[target_col].to_csv(os.path.join(args.outputpath, 'test/test_y.csv'), index=False, header=False)\n",
    "    test_data.drop([target_col], axis=1).to_csv(os.path.join(args.outputpath, 'test/test_x.csv'), index=False, header=False)\n",
    "    \n",
    "    # Save the baseline dataset for model monitoring\n",
    "    df_model_data.drop([target_col], axis=1).to_csv(os.path.join(args.outputpath, 'baseline/baseline.csv'), index=False, header=False)\n",
    "    \n",
    "    print(\"## Processing complete. Exiting.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317a90ba-5610-43fb-a39c-6eae4e2b6691",
   "metadata": {},
   "source": [
    "The processing script contains a statement to save the whole dataset without the header and the label column as a baseline dataset. You need the data baseline later on in the model monitoring notebook."
   ]
  },
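  {
   "cell_type": "markdown",
   "id": "3c8e5f10-7a42-4d19-b6e3-52f0a9d1c701",
   "metadata": {},
   "source": [
    "The shuffle-and-split step in `preprocessing.py` cuts the shuffled rows at 70% and 90% of the dataset, which yields a 70/20/10 split into train, validation, and test sets. The following is a minimal pure-Python sketch of the same index arithmetic. The `split_by_cut_points` helper is hypothetical and is not part of the script:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def split_by_cut_points(rows, fractions=(0.7, 0.9), seed=1729):\n",
    "    # Shuffle a copy of the rows, then cut it at the cumulative fractions\n",
    "    shuffled = list(rows)\n",
    "    random.Random(seed).shuffle(shuffled)\n",
    "    cuts = [int(f * len(shuffled)) for f in fractions]\n",
    "    bounds = [0, *cuts, len(shuffled)]\n",
    "    return [shuffled[start:end] for start, end in zip(bounds, bounds[1:])]\n",
    "\n",
    "\n",
    "train, validation, test = split_by_cut_points(range(1000))\n",
    "print(len(train), len(validation), len(test))\n",
    "```\n",
    "\n",
    "The validation share is the gap between the two cut points (0.9 - 0.7 = 0.2), and the test set is whatever remains after the last cut."
   ]
  },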
  {
   "cell_type": "markdown",
   "id": "70ee5a09-46fd-4d7a-9037-9b6247a1f21d",
   "metadata": {},
   "source": [
    "Set the Amazon S3 paths:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fa1abd-c0ae-4e55-901c-7250c227ccbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/train\"\n",
    "validation_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/validation\"\n",
    "test_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/test\"\n",
    "baseline_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/baseline\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed6a2d8e-acde-48a8-9f12-a1cd0d703a41",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%store train_s3_url\n",
    "%store validation_s3_url\n",
    "%store test_s3_url\n",
    "%store baseline_s3_url"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03762d62-0f9c-4766-8f5e-d554bfc8437c",
   "metadata": {},
   "source": [
    "Set the framework version and type and number of compute instances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93215af-1524-43d0-b08b-ab265936ec2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "framework_version = \"0.23-1\"\n",
    "processing_instance_type = \"ml.m5.large\"\n",
    "processing_instance_count = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b730b8ee-f2fe-4612-870a-edb0e0ffc276",
   "metadata": {},
   "source": [
    "### Create a run\n",
    "Create a new run in your experiment to track parameters, configuration, inputs, and outputs of the processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e28192-db7d-4a93-8c55-de2abd6fedf6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "run_suffix = strftime('%Y-%m-%M-%S', gmtime())\n",
    "run_name = f\"container-processing-{run_suffix}\"\n",
    "\n",
    "with Run(experiment_name=experiment_name,\n",
    "         run_name=run_name,\n",
    "         run_display_name=\"container-processing\",\n",
    "         sagemaker_session=session\n",
    "        ) as run:\n",
    "    run.log_parameters(\n",
    "        {\n",
    "            \"train\": 0.7,\n",
    "            \"validate\": 0.2,\n",
    "            \"test\": 0.1\n",
    "        }\n",
    "    )\n",
    "   \n",
    "    experiment_config = run.experiment_config\n",
    "    # time.sleep(8) # wait until resource tags are propagated to the run"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b925880-49f1-4571-9761-02c4a96dbd62",
   "metadata": {},
   "source": [
    "### Create a processor\n",
    "Instantiate a [`SKLearnProcessor`](https://sagemaker.readthedocs.io/en/stable/frameworks/sklearn/sagemaker.sklearn.html#scikit-learn-processor) object before starting the SageMaker processing job. You specify the instance type to use in the job, as well as how many instances for distributed processing.\n",
    "\n",
    "Note how SageMaker maps your data to the local paths on the processing container's EBS volume:\n",
    "\n",
    "![](img/data-processing.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec206f04-99cc-4a4f-b622-1bb8fdac1460",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sklearn_processor = SKLearnProcessor(\n",
    "    framework_version=framework_version,\n",
    "    role=sm_role,\n",
    "    instance_type=processing_instance_type,\n",
    "    instance_count=processing_instance_count, \n",
    "    base_job_name='from-idea-to-prod-processing',\n",
    "    sagemaker_session=session,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b4253a9-66b4-4749-b356-28af6cf67141",
   "metadata": {},
   "source": [
    "#### Define processing inputs and outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1503fe76-cd04-4e5a-86a1-ea3cbb45f09d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "processing_inputs = [\n",
    "        ProcessingInput(\n",
    "            source=input_s3_url, \n",
    "            destination=\"/opt/ml/processing/input\",\n",
    "            s3_input_mode=\"File\",\n",
    "            s3_data_distribution_type=\"ShardedByS3Key\"\n",
    "        )\n",
    "    ]\n",
    "\n",
    "processing_outputs = [\n",
    "        ProcessingOutput(\n",
    "            output_name=\"train_data\", \n",
    "            source=\"/opt/ml/processing/output/train\",\n",
    "            destination=train_s3_url,\n",
    "        ),\n",
    "        ProcessingOutput(\n",
    "            output_name=\"validation_data\", \n",
    "            source=\"/opt/ml/processing/output/validation\", \n",
    "            destination=validation_s3_url\n",
    "        ),\n",
    "        ProcessingOutput(\n",
    "            output_name=\"test_data\", \n",
    "            source=\"/opt/ml/processing/output/test\", \n",
    "            destination=test_s3_url\n",
    "        ),\n",
    "        ProcessingOutput(\n",
    "            output_name=\"baseline_data\", \n",
    "            source=\"/opt/ml/processing/output/baseline\", \n",
    "            destination=baseline_s3_url\n",
    "        ),\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bedef440-c13a-4d23-90af-ef747aeef8bc",
   "metadata": {},
   "source": [
    "#### Start the SageMaker processing job\n",
    "\n",
    "<img src=\"\" alt=\"Time alert open medium\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab839f9-6194-4467-b9d9-b207c15eb729",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    sklearn_processor.run(\n",
    "        inputs=processing_inputs,\n",
    "        outputs=processing_outputs,\n",
    "        code='preprocessing.py',\n",
    "        wait=True,\n",
    "        experiment_config=experiment_config,\n",
    "        # arguments = ['arg1', 'arg2'],\n",
    "    )\n",
    "except botocore.exceptions.ClientError as e:\n",
    "    if e.response['Error']['Code'] == 'AccessDeniedException':\n",
    "        print(f\"Ignore AccessDeniedException: {e.response['Error']['Message']} because of the slow resource tag auto propagation\")\n",
    "    else:\n",
    "        raise e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d665ecf7-d836-40db-960e-ad8075c1ced4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# If you set wait to False in the previous code cell, wait until the job completes\n",
    "while sm.describe_processing_job(\n",
    "        ProcessingJobName=sklearn_processor._current_job_name\n",
    "    )[\"ProcessingJobStatus\"] != \"Completed\":\n",
    "    time.sleep(10)\n",
    "    print(f\"Wait until {sklearn_processor._current_job_name} completed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "795e7281-d03b-42ba-8ef4-b0b30f9de0fd",
   "metadata": {},
   "source": [
    "To wait for job completion you can also use `boto3` [waiters](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#waiters). For example:\n",
    "\n",
    "```python\n",
    "waiter = session.sagemaker_client.get_waiter('processing_job_completed_or_stopped')\n",
    "waiter.wait(ProcessingJobName=sklearn_processor._current_job_name)\n",
    "```"
   ]
  },
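  {
   "cell_type": "markdown",
   "id": "8b1d4e27-5c90-4f3a-a2d6-0e7f6c9b3a02",
   "metadata": {},
   "source": [
    "When you poll by hand, it helps to stop on any terminal status and to cap the total wait time. The following is a minimal sketch under those assumptions. The `wait_for_terminal_status` helper is hypothetical, and a fake `describe` callable stands in for the `describe_processing_job` call so that the example runs without AWS access:\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "# Statuses after which a processing job no longer changes state\n",
    "TERMINAL_STATUSES = {'Completed', 'Failed', 'Stopped'}\n",
    "\n",
    "\n",
    "def wait_for_terminal_status(describe, poll_seconds=10, timeout_seconds=3600, sleep=time.sleep):\n",
    "    # Call describe() until it returns a terminal status or the timeout expires\n",
    "    waited = 0\n",
    "    while True:\n",
    "        status = describe()\n",
    "        if status in TERMINAL_STATUSES:\n",
    "            return status\n",
    "        if waited >= timeout_seconds:\n",
    "            raise TimeoutError(f'Job still {status} after {waited} seconds')\n",
    "        sleep(poll_seconds)\n",
    "        waited += poll_seconds\n",
    "\n",
    "\n",
    "# Fake describe call: reports InProgress twice, then Failed\n",
    "fake_statuses = iter(['InProgress', 'InProgress', 'Failed'])\n",
    "final_status = wait_for_terminal_status(lambda: next(fake_statuses), sleep=lambda seconds: None)\n",
    "print(final_status)\n",
    "```\n",
    "\n",
    "With the real client, `describe` could be `lambda: sm.describe_processing_job(ProcessingJobName=job_name)[\"ProcessingJobStatus\"]`, where `job_name` is the name of your processing job."
   ]
  },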
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad2ceb4b-ece3-40c6-8674-ce7b8db70a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the uploaded files\n",
    "!aws s3 ls {bucket_name}/{bucket_prefix} --recursive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d25d03a-e988-4a73-acdd-874718d03a74",
   "metadata": {},
   "source": [
    "## Model training\n",
    "Follow the same approach and now run the model training as a [SageMaker training job](https://sagemaker.readthedocs.io/en/stable/overview.html#using-estimators)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfbeec6d-419d-4d88-a3f6-194440dd0124",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get training container uri\n",
    "training_image = sagemaker.image_uris.retrieve(\"xgboost\", region=region, version=\"1.5-1\")\n",
    "\n",
    "print(training_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64034359-5102-4678-af52-aba7f21a4212",
   "metadata": {},
   "source": [
    "Define the data input channels for the training job. Set _train_ and _validation_ channels via the SageMaker SDK [`TrainingInput`](https://sagemaker.readthedocs.io/en/stable/api/utility/inputs.html#sagemaker.inputs.TrainingInput) class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82550351-737f-4a55-9578-0ae58ca1c156",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "s3_input_train = sagemaker.inputs.TrainingInput(train_s3_url, content_type='csv')\n",
    "s3_input_validation = sagemaker.inputs.TrainingInput(validation_s3_url, content_type='csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86c9a303-928a-4053-9002-cd618cffa6dd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_instance_count = 1\n",
    "train_instance_type = \"ml.m5.xlarge\"\n",
    "\n",
    "# Define where the training job stores the model artifact\n",
    "output_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/output\"\n",
    "\n",
    "%store output_s3_url"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54807668-6a2d-44ac-9007-d945ec19198a",
   "metadata": {},
   "source": [
    "Instantiate an [Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) object and set algorithm's hyperparameters. Refer to [XGBoost hyperparameters](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost_hyperparameters.html) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1a89cc-c841-4e5d-99b3-d4fe5b6435cc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate an XGBoost estimator object\n",
    "estimator = sagemaker.estimator.Estimator(\n",
    "    image_uri=training_image,  # XGBoost algorithm container\n",
    "    instance_type=train_instance_type,  # type of training instance\n",
    "    instance_count=train_instance_count,  # number of instances to be used\n",
    "    role=sm_role,  # IAM execution role to be used\n",
    "    max_run=20 * 60,  # Maximum allowed active runtime\n",
    "    # use_spot_instances=True,  # Use spot instances to reduce cost\n",
    "    # max_wait=30 * 60,  # Maximum clock time (including spot delays)\n",
    "    output_path=output_s3_url, # S3 location for saving the training result\n",
    "    sagemaker_session=session, # Session object which manages interactions with SageMaker API and AWS services\n",
    "    base_job_name=\"from-idea-to-prod-training\", # Prefix for training job name\n",
    ")\n",
    "\n",
    "# define its hyperparameters\n",
    "estimator.set_hyperparameters(\n",
    "    num_round=150, # the number of rounds to run the training\n",
    "    max_depth=3, # maximum depth of a tree\n",
    "    eta=0.5, # step size shrinkage used in updates to prevent overfitting\n",
    "    alpha=2.5, # L1 regularization term on weights\n",
    "    objective=\"binary:logistic\",\n",
    "    eval_metric=\"auc\", # evaluation metrics for validation data\n",
    "    subsample=0.8, # subsample ratio of the training instance\n",
    "    colsample_bytree=0.8, # subsample ratio of columns when constructing each tree\n",
    "    min_child_weight=3, # minimum sum of instance weight (hessian) needed in a child\n",
    "    early_stopping_rounds=10, # the model trains until the validation score stops improving\n",
    "    verbosity=1, # verbosity of printing messages\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3022218-4f87-4e59-8f58-9bab2a44bb96",
   "metadata": {},
   "source": [
    "Run the training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856cde4f-c1a3-436e-ad17-daeb77ebd60f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_inputs = {'train': s3_input_train, 'validation': s3_input_validation}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "287da7a9-c0ca-479e-8677-24176d91f95f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    run_suffix = strftime('%Y-%m-%M-%S', gmtime())\n",
    "    run_name = f\"container-training-{run_suffix}\"\n",
    "\n",
    "    with Run(experiment_name=experiment_name,\n",
    "             run_name=run_name,\n",
    "             run_display_name=\"container-training\",\n",
    "             sagemaker_session=session\n",
    "            ) as run:\n",
    "        \n",
    "        estimator.fit(\n",
    "            training_inputs,\n",
    "            wait=True,\n",
    "            logs=False,\n",
    "        ) \n",
    "except botocore.exceptions.ClientError as e:\n",
    "    if e.response['Error']['Code'] == 'AccessDeniedException':\n",
    "        print(f\"Ignore AccessDeniedException: {e.response['Error']['Message']} because of the slow resource tag auto propagation\")\n",
    "    else:\n",
    "        raise e"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4525ff7c-527d-4cd3-8708-83d516e47a02",
   "metadata": {},
   "source": [
    "### Reduce training job startup time with warm pools\n",
    "💡 Instead of using each time a new ephemeral computation cluster to train your models, you can keep your model training hardware instances warm after every job for a specified period. Refer to [Reduce ML Model Training Job startup time by up to 8x using SageMaker Training Managed Warm Pools](https://aws.amazon.com/about-aws/whats-new/2022/09/reduce-ml-model-training-job-startup-time-8x-sagemaker-training-managed-warm-pools/) for more details. If you opt to use warm pools, you are billed for the instances and EBS volumens for the duration of the keep-alive period. \n",
    "Refer to [ Train Using SageMaker Managed Warm Pools](https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html) in the Amazon SageMaker Developer Guide for details on training API."
   ]
  },
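  {
   "cell_type": "markdown",
   "id": "c47a92e6-1f35-4b08-9d7c-6a3e5b2f8d03",
   "metadata": {},
   "source": [
    "As a hedged sketch, the SageMaker Python SDK exposes warm pools through the `keep_alive_period_in_seconds` estimator argument. The ten-minute value and the `warm_pool_kwargs` name below are illustrative assumptions, and your account may need a warm pool quota for the chosen instance type:\n",
    "\n",
    "```python\n",
    "# Assumption: keep the training instances warm for 10 minutes after each job\n",
    "warm_pool_kwargs = {\n",
    "    'keep_alive_period_in_seconds': 10 * 60,\n",
    "}\n",
    "\n",
    "# Pass the extra argument when you construct the estimator, for example:\n",
    "# estimator = sagemaker.estimator.Estimator(\n",
    "#     image_uri=training_image,\n",
    "#     role=sm_role,\n",
    "#     instance_type=train_instance_type,\n",
    "#     instance_count=train_instance_count,\n",
    "#     sagemaker_session=session,\n",
    "#     **warm_pool_kwargs,\n",
    "# )\n",
    "print(warm_pool_kwargs)\n",
    "```"
   ]
  },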
  {
   "cell_type": "markdown",
   "id": "a7b9af81-09c1-45b1-b59b-489a0d14fa76",
   "metadata": {},
   "source": [
    "### Output model performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "052e9640-3a0b-4a30-a897-ecd2ee102ca4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if estimator._current_job_name:\n",
    "    training_job_name = estimator._current_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69d7a0c-8cbe-4fd1-af44-9677107b5229",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "metrics = None\n",
    "while not metrics:\n",
    "    metrics = sm.describe_training_job(\n",
    "        TrainingJobName=training_job_name\n",
    "        ).get(\"FinalMetricDataList\")\n",
    "\n",
    "    if not metrics:\n",
    "        print(f\"Training job {training_job_name} hasn't finished yet!\")\n",
    "        time.sleep(10)\n",
    "    \n",
    "train_auc = float([m['Value'] for m in metrics if m['MetricName'] == 'train:auc'][0])\n",
    "validate_auc = float([m['Value'] for m in metrics if m['MetricName'] == 'validation:auc'][0])\n",
    "\n",
    "print(f\"Train-auc:{train_auc:.2f}, Validate-auc:{validate_auc:.2f}\")"
   ]
  },
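  {
   "cell_type": "markdown",
   "id": "e05b7d39-2a68-4c14-8f91-7d4c0a6e1b04",
   "metadata": {},
   "source": [
    "`FinalMetricDataList` is a list of dictionaries with `MetricName` and `Value` keys. A small lookup helper avoids repeating the list comprehension for every metric. In this sketch, `metric_value` is a hypothetical name and the sample values are made up for illustration:\n",
    "\n",
    "```python\n",
    "def metric_value(metrics, name):\n",
    "    # Return the value of the named metric, or raise KeyError if it is absent\n",
    "    for metric in metrics:\n",
    "        if metric['MetricName'] == name:\n",
    "            return float(metric['Value'])\n",
    "    raise KeyError(name)\n",
    "\n",
    "\n",
    "# Made-up sample in the shape returned by describe_training_job\n",
    "sample_metrics = [\n",
    "    {'MetricName': 'train:auc', 'Value': 0.81},\n",
    "    {'MetricName': 'validation:auc', 'Value': 0.77},\n",
    "]\n",
    "print(metric_value(sample_metrics, 'validation:auc'))\n",
    "```\n",
    "\n",
    "Raising `KeyError` for a missing metric gives a clearer failure than the `IndexError` you get from indexing an empty list."
   ]
  },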
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e83a2509-8e2c-4e1c-874f-215745be46e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%store training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05a41f7a-4f7e-413b-9f5f-a7557bf4060a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Print the S3 path to the model artifact:\n",
    "estimator.model_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aa43bd9-7cab-48da-b128-635b90c49971",
   "metadata": {},
   "source": [
    "## Validate model\n",
    "The training job saved a model in the specified location on Amazon S3.\n",
    "\n",
    "You can deploy the model as a [real-time endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html), which is just one [function call](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator.deploy), or create a [batch transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html) to predict a label for a large dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2b02426-0b29-4b42-a182-7e62ba951767",
   "metadata": {},
   "source": [
    "### Real-time inference\n",
    "To test [real-time inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) you create a real-time endpoint using the trained estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feb0d417-99e8-480d-baa4-7ab18d9b5427",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Real-time endpoint\n",
    "endpoint_name = f\"from-idea-to-prod-endpoint-{strftime('%d-%H-%M-%S', gmtime())}\"\n",
    "\n",
    "try:\n",
    "    predictor = estimator.deploy(\n",
    "        initial_instance_count=1,\n",
    "        instance_type=\"ml.m5.large\",\n",
    "        wait=False,  # Remember, predictor.predict() won't work until deployment finishes!\n",
    "        # Turn on data capture here, in case you want to experiment with monitoring:\n",
    "        data_capture_config=sagemaker.model_monitor.DataCaptureConfig(\n",
    "            enable_capture=True,\n",
    "            sampling_percentage=100,\n",
    "            destination_s3_uri=f\"s3://{bucket_name}/{bucket_prefix}/data-capture\",\n",
    "        ),\n",
    "        endpoint_name=endpoint_name,\n",
    "        serializer=sagemaker.serializers.CSVSerializer(),\n",
    "        deserializer=sagemaker.deserializers.CSVDeserializer(),\n",
    "    )\n",
    "except botocore.exceptions.ClientError as e:\n",
    "    if e.response['Error']['Code'] == 'AccessDeniedException':\n",
    "        print(f\"Ignore AccessDeniedException: {e.response['Error']['Message']} because of the slow resource tag auto propagation\")\n",
    "        predictor = sagemaker.predictor.Predictor(endpoint_name=endpoint_name,\n",
    "                                                  sagemaker_session=session,\n",
    "                                                  serializer=sagemaker.serializers.CSVSerializer(),\n",
    "                                                  deserializer=sagemaker.deserializers.CSVDeserializer(),\n",
    "                                                 )\n",
    "    else:\n",
    "        raise e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1714f00-78da-4d99-b665-77822581f50a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wait until the endpoint has the status InService\n",
    "waiter = session.sagemaker_client.get_waiter('endpoint_in_service')\n",
    "waiter.wait(EndpointName=endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e2e7b68-7b53-420d-b716-c62a98bd0c5d",
   "metadata": {},
   "source": [
    "#### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe4d085-0c98-49db-bc61-b14e93e9f21e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!aws s3 cp $test_s3_url/test_x.csv tmp/test_x.csv\n",
    "!aws s3 cp $test_s3_url/test_y.csv tmp/test_y.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbfeb79b-014e-4784-a7bd-52f88b9957db",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_x = pd.read_csv(\"tmp/test_x.csv\", names=[f'{i}' for i in range(59)])\n",
    "test_y = pd.read_csv(\"tmp/test_y.csv\", names=['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5d420ff-2e3e-466a-b8e4-91a2f407f26d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictions = np.array(predictor.predict(test_x.values), dtype=float).squeeze()\n",
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "772ebdad-ab4a-4cf1-bda5-82618d6d6261",
   "metadata": {},
   "source": [
    "#### Evaluate predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce21400a-6cbd-4817-869d-89e289fa165b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_results = pd.concat(\n",
    "    [\n",
    "        pd.Series(predictions, name=\"y_pred\", index=test_x.index),\n",
    "        test_x,\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "test_results.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b8ef37-c50d-465d-bf54-2b104ace6fe7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pd.crosstab(\n",
    "    index=test_y['y'].values,\n",
    "    columns=np.round(predictions), \n",
    "    rownames=['actuals'], \n",
    "    colnames=['predictions']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f994e5e0-b83f-4c39-b9fe-9e0c74b2b4f6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_auc = roc_auc_score(test_y, test_results[\"y_pred\"])\n",
    "print(f\"Test-auc: {test_auc:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b56d63-c998-4ef1-ada2-e85ed1987e1e",
   "metadata": {},
   "source": [
    "### Batch transform\n",
    "If you want to run a prediction on a large dataset or don't need a real-time endpoint, you can use SageMaker [batch-transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff1abd7-4ef7-4f27-a23a-1a254e823fa1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "transform_s3_url = f\"s3://{bucket_name}/{bucket_prefix}/transform\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7efc15f6-aa75-429f-9aed-c42572a06e7f",
   "metadata": {},
   "source": [
    "To create a transformer, use either option 1 or option 2."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc425f73-03f6-481d-a337-d4f457a26ff6",
   "metadata": {},
   "source": [
    "#### Option 1: create a batch transformer from the trained estimator\n",
    "You can use [`EstimatorBase.transformer()`](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase.transformer) to create a transformer for an estimator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6d7f94-b194-4b9a-9430-a1d5b1ec055b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = f\"from-idea-to-prod-transform-{strftime('%d-%H-%M-%S', gmtime())}\"\n",
    "\n",
    "transformer = estimator.transformer(\n",
    "    instance_count=1,\n",
    "    instance_type=train_instance_type,\n",
    "    accept=\"text/csv\",\n",
    "    role=sm_role,\n",
    "    output_path=transform_s3_url,\n",
    "    model_name=model_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d00e05ab-c043-40fa-90dd-24d6850b7130",
   "metadata": {},
   "source": [
    "Skip option 2 and continue with the section **Run transform job**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e49bf855-4b34-4ddb-ada6-736b2122ee18",
   "metadata": {},
   "source": [
    "#### Option 2: load a model from a training job\n",
    "Alternatively, you can load a model from a model artifact produced by a training job. You create a transformer with that model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e971d76c-8e06-4b94-9b25-6eaec8703622",
   "metadata": {},
   "outputs": [],
   "source": [
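    "# Use a fresh model name so this cell also works if you skipped option 1\n",
    "model_name = f\"from-idea-to-prod-transform-{strftime('%d-%H-%M-%S', gmtime())}\"\n",
    "\n",
    "# create_model_from_job returns the name of the created SageMaker model\n",
    "model = session.create_model_from_job(\n",
    "    training_job_name=training_job_name, \n",
    "    name=model_name,\n",
    ")"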
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50993f97-dce2-4084-a1de-22658fc8b669",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = sagemaker.transformer.Transformer(\n",
    "    model_name=model,\n",
    "    instance_count=1,\n",
    "    instance_type=train_instance_type,\n",
    "    accept=\"text/csv\",\n",
    "    assemble_with=\"Line\",\n",
    "    output_path=transform_s3_url,\n",
    "    base_transform_job_name=\"from-idea-to-prod-transform\",\n",
    "    sagemaker_session=session,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60af22db-848d-44e8-987e-c22aedbd5354",
   "metadata": {},
   "source": [
    "#### Run transform job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b50377c-0196-41a3-9c8c-70ee2266c160",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "transform_job_name = f\"from-idea-to-prod-transform-{strftime('%d-%H-%M-%S', gmtime())}\"\n",
    "\n",
    "try:\n",
    "    run_suffix = strftime('%Y-%m-%d-%H-%M-%S', gmtime())\n",
    "    run_name = f\"batch-transform-{run_suffix}\"\n",
    "\n",
    "    with Run(experiment_name=experiment_name,\n",
    "             run_name=run_name,\n",
    "             run_display_name=\"batch-transform\",\n",
    "             sagemaker_session=session\n",
    "            ) as run:\n",
    "        transformer.transform(    \n",
    "            data=f\"{test_s3_url}/test_x.csv\",\n",
    "            content_type=\"text/csv\",\n",
    "            split_type=\"Line\", \n",
    "            job_name=transform_job_name,\n",
    "            wait=True,\n",
    "            # experiment_config=experiment_config,\n",
    "        )\n",
    "except botocore.exceptions.ClientError as e:\n",
    "    if e.response['Error']['Code'] == 'AccessDeniedException':\n",
    "        print(f\"Ignore AccessDeniedException: {e.response['Error']['Message']} because of the slow resource tag auto propagation\")\n",
    "    else:\n",
    "        raise e"
   ]
  },
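  {
   "cell_type": "markdown",
   "id": "7d2f4a68-1c3e-4b5a-9f70-2e4c6a8b0d13",
   "metadata": {},
   "source": [
    "Job and run names in this notebook get a timestamp suffix built with `strftime`. As a small sketch with a fixed, made-up timestamp, the following shows how the format codes map to date and time fields. Note that `%m` is the month and `%M` is the minute:\n",
    "\n",
    "```python\n",
    "from datetime import datetime, timezone\n",
    "from time import strftime\n",
    "\n",
    "# Fixed example time (5 March 2024, 14:07:09 UTC) so the output is reproducible.\n",
    "fixed = datetime(2024, 3, 5, 14, 7, 9, tzinfo=timezone.utc).timetuple()\n",
    "\n",
    "short_suffix = strftime('%d-%H-%M-%S', fixed)  # day-hour-minute-second\n",
    "full_suffix = strftime('%Y-%m-%d-%H-%M-%S', fixed)  # year-month-day-hour-minute-second\n",
    "\n",
    "print(short_suffix)  # 05-14-07-09\n",
    "print(full_suffix)  # 2024-03-05-14-07-09\n",
    "```\n",
    "\n",
    "A suffix that includes every field from year to second stays unique across days, which matters for names that must not collide."
   ]
  },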
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f11a902-1306-430f-8925-2ddd17741609",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
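    "# Poll until the job reaches a terminal state, so a failed job does not loop forever\n",
    "while True:\n",
    "    job = sm.describe_transform_job(TransformJobName=transformer._current_job_name)\n",
    "    status = job[\"TransformJobStatus\"]\n",
    "    if status == \"Completed\":\n",
    "        break\n",
    "    if status in (\"Failed\", \"Stopped\"):\n",
    "        raise RuntimeError(f\"{transformer._current_job_name} ended with status {status}\")\n",
    "    print(f\"Wait until {transformer._current_job_name} completed\")\n",
    "    time.sleep(10)"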
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c80898-b328-4413-bee2-34faf29f5f22",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "transformer.output_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "672524d0-d056-4a1d-b80a-b4e4822c8290",
   "metadata": {},
   "source": [
    "#### Evaluate predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc287aa8-de90-4de3-b59c-d7fc0e9aad2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!aws s3 ls {transformer.output_path}/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec638017-8bc3-498a-8b0d-26b5e85c0cde",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!aws s3 cp {transformer.output_path}/test_x.csv.out tmp/predictions.csv\n",
    "!aws s3 cp $test_s3_url/test_y.csv tmp/test_y.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d5e4e16-4739-437c-8237-2083aae150be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictions = pd.read_csv(\"tmp/predictions.csv\", names=[\"y_prob\"])\n",
    "test_y = pd.read_csv(\"tmp/test_y.csv\", names=['y'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54f1497b-2723-4e37-a1e0-75117dcc91af",
   "metadata": {},
   "source": [
    "#### Crosstab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5e150c-1bf7-414c-af9d-54ed108117aa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pd.crosstab(\n",
    "    index=test_y['y'].values,\n",
    "    columns=np.array(np.round(predictions), dtype=float).squeeze(), \n",
    "    rownames=['actuals'], \n",
    "    colnames=['predictions']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd11aa9b-9bcc-4e70-a508-93f02781fa73",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_auc = roc_auc_score(test_y, predictions)\n",
    "print(f\"Test-auc: {test_auc:.2f}\")"
   ]
  },
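  {
   "cell_type": "markdown",
   "id": "9a1c3e57-6b8d-4f20-a4c6-3f5b7d9e1a24",
   "metadata": {},
   "source": [
    "To make the `Test-auc` value more concrete: ROC AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. The following is a minimal pure-Python sketch of that definition on made-up values. It is not the scikit-learn implementation, and `pairwise_auc` is a hypothetical helper name:\n",
    "\n",
    "```python\n",
    "def pairwise_auc(y_true, y_score):\n",
    "    # Split the scores by actual label.\n",
    "    pos = [s for y, s in zip(y_true, y_score) if y == 1]\n",
    "    neg = [s for y, s in zip(y_true, y_score) if y == 0]\n",
    "    if not pos or not neg:\n",
    "        raise ValueError('AUC needs at least one positive and one negative label')\n",
    "    # A correctly ordered pair counts 1, a tie counts 0.5, a wrong order counts 0.\n",
    "    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)\n",
    "    return wins / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "# Made-up labels and scores, for illustration only.\n",
    "labels = [0, 0, 1, 1]\n",
    "scores = [0.1, 0.4, 0.35, 0.8]\n",
    "print(f'AUC: {pairwise_auc(labels, scores):.2f}')  # AUC: 0.75\n",
    "```\n",
    "\n",
    "Three of the four positive-negative pairs are ordered correctly here, which gives 0.75. The pairwise loop is quadratic, so it is only suitable for small illustrations."
   ]
  },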
  {
   "cell_type": "markdown",
   "id": "9d358952-5b3a-4690-b118-1677bd80fd5c",
   "metadata": {},
   "source": [
    "#### ROC curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabaa6a7-08a2-4e21-95e2-17fea4c6e692",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn import metrics\n",
    "from sklearn.metrics import RocCurveDisplay\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Compute the ROC curve and its area from the batch transform probabilities\n",
    "fpr, tpr, thresholds = metrics.roc_curve(test_y, predictions)\n",
    "roc_auc = metrics.auc(fpr, tpr)\n",
    "# Avoid the name `display`, which would shadow IPython's built-in display function\n",
    "roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='Holdout/Test Data - ROC curve')\n",
    "roc_display.plot()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34c19463-4fb0-43bd-bb33-9f3a174c060e",
   "metadata": {},
   "source": [
    "#### Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449ffe9b-1b7d-44d9-a9ce-c8d68c97de7e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "\n",
    "cm = confusion_matrix(test_y, np.round(predictions))\n",
    "f = sns.heatmap(cm, annot=True, fmt='d')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b90e2e7f-69fa-409b-a36f-341842d50146",
   "metadata": {},
   "source": [
    "#### Precision-recall curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c36302d4-5855-43cd-9d47-e5f2d50bfaf5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import PrecisionRecallDisplay\n",
    "\n",
    "prec, recall, _ = precision_recall_curve(test_y, predictions)\n",
    "average_precision = metrics.average_precision_score(test_y, predictions)\n",
    "pr_display = PrecisionRecallDisplay(precision=prec, recall=recall, average_precision=average_precision, estimator_name='Holdout/Test Data - AUPRC curve')\n",
    "pr_display.plot()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d3f53ac-3770-44e8-8460-797ae33e6642",
   "metadata": {},
   "source": [
    "### Save charts to the experiment run\n",
    "You can use the [`experiments.load_run()`](https://sagemaker.readthedocs.io/en/stable/experiments/sagemaker.experiments.html#sagemaker.experiments.load_run) method to load an existing run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47261776-0ac2-48d1-9bb1-8ad16cc200fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "title_suffix = strftime('%Y-%m-%d-%H-%M-%S', gmtime())\n",
    "\n",
    "with load_run(experiment_name=experiment_name, run_name=run_name) as run:\n",
    "    print(run.experiment_config)\n",
    "    run.log_confusion_matrix(y_true=test_y['y'].values,\n",
    "                             y_pred=np.array(np.round(predictions), dtype=float).squeeze(), \n",
    "                             title=f\"confusion-matrix-{title_suffix}\")\n",
    "    run.log_roc_curve(y_true=test_y['y'].values, \n",
    "                      y_score=predictions['y_prob'].values, \n",
    "                      title=f\"roc-curve2-{title_suffix}\")\n",
    "    run.log_precision_recall(y_true=test_y['y'].values,\n",
    "                             predicted_probabilities=predictions['y_prob'].values,\n",
    "                             positive_label=1,\n",
    "                             title=f\"precision-recall-{title_suffix}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de6ee7d6-e8d7-40aa-bae6-8db6edbcc10d",
   "metadata": {},
   "source": [
    "## Explore experiments and runs with Studio UX\n",
    "You can see all logged metrics, parameters, and artifacts in the Studio UX under **SageMaker Home** > **Experiments**.\n",
    "\n",
    "For example, choose the name of the experiment you used in this notebook:\n",
    "\n",
    "![](img/experiment-and-runs.png)\n",
    "\n",
    "You see the runs that you created in this notebook:\n",
    "\n",
    "![](img/runs-02.png)\n",
    "\n",
    "Select the `batch-transform-<timestamp>` run and choose **Charts** in the **Overview** section on the left pane. You see the three charts that you added to the run:\n",
    "\n",
    "![](img/run-charts.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6949ec51-5f38-4432-9a5d-ed4e5350644f",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d8a7389-9af4-4a0b-8ea7-376db0c2ed02",
   "metadata": {},
   "source": [
    "## Optional: Hyperparameter optimization (HPO)\n",
    "It takes about 20 minutes to run this section. The section is optional, and you don't need to run it to continue with the other notebooks. You can go directly to the step 3 [notebook](03-sagemaker-pipeline.ipynb). If you would like to perform a model A/B test in the **Additional topics** section, run this part to produce an alternative model.\n",
    "\n",
    "[Amazon SageMaker automatic model tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html), also called hyperparameter optimization (HPO), finds the best-performing model against a defined objective metric by running many training jobs on the dataset using the algorithm and ranges of hyperparameters that you specify. SageMaker HPO supports random search, Bayesian optimization, and [Hyperband](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html) as tuning strategies."
   ]
  },
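  {
   "cell_type": "markdown",
   "id": "5e8b0d24-2f4a-4c6e-b1d3-7a9c1e3f5b35",
   "metadata": {},
   "source": [
    "To illustrate the idea behind tuning, here is a minimal pure-Python sketch of random search, the simplest of the strategies listed above. It is not how SageMaker implements tuning: `toy_objective` is a made-up stand-in for a training job that reports a validation metric, and all names in the sketch are hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def toy_objective(max_depth, eta):\n",
    "    # Stand-in for a training job: a made-up score that peaks at max_depth=6, eta=0.3.\n",
    "    return 1.0 - 0.01 * (max_depth - 6) ** 2 - (eta - 0.3) ** 2\n",
    "\n",
    "\n",
    "def random_search(n_trials, seed=0):\n",
    "    # Sample hyperparameters at random from their ranges and keep the best-scoring trial.\n",
    "    rng = random.Random(seed)\n",
    "    best = None\n",
    "    for _ in range(n_trials):\n",
    "        params = {'max_depth': rng.randint(1, 10), 'eta': rng.uniform(0.0, 1.0)}\n",
    "        score = toy_objective(**params)\n",
    "        if best is None or score > best[0]:\n",
    "            best = (score, params)\n",
    "    return best\n",
    "\n",
    "\n",
    "best_score, best_params = random_search(n_trials=30)\n",
    "print(f'best score {best_score:.3f} with {best_params}')\n",
    "```\n",
    "\n",
    "Each trial here is independent of the others. Bayesian optimization differs in that it uses the results of earlier trials to decide which hyperparameter values to try next."
   ]
  },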
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5230ec51-bab2-40ca-a5ce-3a31590d47a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import required HPO objects\n",
    "from sagemaker.tuner import (\n",
    "    CategoricalParameter,\n",
    "    ContinuousParameter,\n",
    "    HyperparameterTuner,\n",
    "    IntegerParameter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a31feb51-0ca9-4517-8c1a-02415970a46b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up hyperparameter ranges\n",
    "hp_ranges = {\n",
    "    \"min_child_weight\": ContinuousParameter(1, 10),\n",
    "    \"max_depth\": IntegerParameter(1, 10),\n",
    "    \"alpha\": ContinuousParameter(0, 5),\n",
    "    \"eta\": ContinuousParameter(0, 1),\n",
    "    \"colsample_bytree\": ContinuousParameter(0, 1),\n",
    "    \"gamma\": ContinuousParameter(0, 10)\n",
    "    \n",
    "}\n",
    "\n",
    "# set up the objective metric\n",
    "objective = \"validation:auc\"\n",
    "\n",
    "# instantiate a HPO object\n",
    "tuner = HyperparameterTuner(\n",
    "    estimator=estimator,  # the SageMaker estimator object\n",
    "    hyperparameter_ranges=hp_ranges,  # the range of hyperparameters\n",
    "    max_jobs=30,  # total number of training jobs the tuner launches\n",
    "    max_parallel_jobs=3,  # how many training jobs can run in parallel\n",
    "    strategy=\"Bayesian\",  # the internal optimization strategy of HPO\n",
    "    objective_metric_name=objective,  # the objective metric to be used for HPO\n",
    "    objective_type=\"Maximize\",  # maximize or minimize the objective metric\n",
    "    base_tuning_job_name=\"from-idea-to-prod-hpo\",\n",
    "    early_stopping_type=\"Auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aab138e-5629-4a9d-830a-efb425a375bb",
   "metadata": {},
   "source": [
    "Now run the HPO job. It takes about 10 minutes to complete. \n",
    "\n",
    "<div class=\"alert alert-info\"> 💡 Note that the HPO job creates its own experiment to track each training job with a specific set of hyperparameters as a separate run.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612a815f-c9f9-4e8e-9afb-ffd3b98a0f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner.fit(\n",
    "    {\"train\": s3_input_train, \"validation\": s3_input_validation},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d60e7673-1d4f-4632-b5cc-c0745cb83260",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"HPO job status: {sm.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuner.latest_tuning_job.job_name)['HyperParameterTuningJobStatus']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc88733a-0711-4248-9acd-c7d78e4f1830",
   "metadata": {},
   "outputs": [],
   "source": [
    "hpo_predictor = tuner.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.m5.large\",\n",
    "    serializer=sagemaker.serializers.CSVSerializer(),\n",
    "    deserializer=sagemaker.deserializers.CSVDeserializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3337f1c6-f211-4054-907f-ef0ee17683f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "hpo_predictions = np.array(hpo_predictor.predict(test_x.values), dtype=float).squeeze()\n",
    "print(hpo_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d77cc440-3828-4677-b19b-ede1edcb3903",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.crosstab(\n",
    "    index=test_y['y'].values,\n",
    "    columns=np.round(hpo_predictions), \n",
    "    rownames=['actuals'], \n",
    "    colnames=['predictions']\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c06cc7b-81f9-42d6-9445-c35b0409f486",
   "metadata": {},
   "source": [
    "There are no material improvements in the model metrics. This can indicate that the XGBoost model is already at its limit. You might want to explore other model types to improve the prediction accuracy for this use case."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64fd8435-9c7e-4b4f-87b2-13ee830d19e3",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5493780-c392-4432-a34b-75f27ce984a5",
   "metadata": {},
   "source": [
    "## Clean-up\n",
    "To avoid charges, remove the hosted endpoints you created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de13080-5527-4de4-aacb-3eb50af0f92f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictor.delete_endpoint(delete_endpoint_config=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040e95d7-b906-4a7c-90cb-4d73cefbf617",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run if you created a tuned predictor after HPO\n",
    "hpo_predictor.delete_endpoint(delete_endpoint_config=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34f720e0-e8d5-4cb5-93a6-6a4f2dafd523",
   "metadata": {},
   "source": [
    "## Continue with step 3\n",
    "Open the step 3 [notebook](03-sagemaker-pipeline.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62b623d7-4b16-4c6f-96a8-ad92a77b1601",
   "metadata": {},
   "source": [
    "## Further development ideas for your real-world projects\n",
    "- Track, organize, and compare all your model training runs using [SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html)\n",
    "- Use [Amazon SageMaker Data Wrangler](https://aws.amazon.com/sagemaker/data-wrangler/) for creating a no-code or low-code visual data processing and feature engineering flow. Refer to this hands-on tutorial: [Prepare Training Data for Machine Learning with Minimal Code](https://aws.amazon.com/getting-started/hands-on/machine-learning-tutorial-prepare-data-with-minimal-code/)\n",
    "- Try no-code [SageMaker Canvas](https://docs.aws.amazon.com/sagemaker/latest/dg/canvas.html) on your data to perform analysis and use automated ML to build models and generate predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ba1fc9d-7a52-4d2f-b51e-0f9422119389",
   "metadata": {},
   "source": [
    "## Additional resources\n",
    "- [Using Docker containers with SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers.html)\n",
    "- [How to create and use a custom SageMaker container: SageMaker hands-on workshop](https://sagemaker-workshop.com/custom/containers.html)\n",
    "- [Amazon SageMaker Immersion Day](https://catalog.us-east-1.prod.workshops.aws/workshops/63069e26-921c-4ce1-9cc7-dd882ff62575/en-US)\n",
    "- [Targeting Direct Marketing with Amazon SageMaker XGBoost](https://github.com/aws-samples/amazon-sagemaker-immersion-day/blob/master/processing_xgboost.ipynb)\n",
    "- [Train a Machine Learning Model](https://aws.amazon.com/getting-started/hands-on/machine-learning-tutorial-train-a-model/)\n",
    "- [Deploy a Machine Learning Model to a Real-Time Inference Endpoint](https://aws.amazon.com/getting-started/hands-on/machine-learning-tutorial-deploy-model-to-real-time-inference-endpoint/)\n",
    "- [Amazon SageMaker 101 Workshop](https://catalog.us-east-1.prod.workshops.aws/workshops/0c6b8a23-b837-4e0f-b2e2-4a3ffd7d645b/en-US)\n",
    "- [Amazon SageMaker 101 Workshop code repository](https://github.com/aws-samples/sagemaker-101-workshop)\n",
    "- [Amazon SageMaker with XGBoost and Hyperparameter Tuning for Direct Marketing predictions](https://github.com/aws-samples/sagemaker-101-workshop/blob/main/builtin_algorithm_hpo_tabular/SageMaker%20XGBoost%20HPO.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7b222f-ee83-4c87-9019-cebcc7e56627",
   "metadata": {},
   "source": [
    "# Shutdown kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae0b1f02-44b8-4749-a5e2-ce415fff7f78",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "725f9d18-a550-46f6-9de8-9aeeed8171fd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "availableInstances": [
   {
    "_defaultOrder": 0,
    "_isFastLaunch": true,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 4,
    "name": "ml.t3.medium",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 1,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 8,
    "name": "ml.t3.large",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 2,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 16,
    "name": "ml.t3.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 3,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 32,
    "name": "ml.t3.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 4,
    "_isFastLaunch": true,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 8,
    "name": "ml.m5.large",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 5,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 16,
    "name": "ml.m5.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 6,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 32,
    "name": "ml.m5.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 7,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 64,
    "name": "ml.m5.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 8,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 128,
    "name": "ml.m5.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 9,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 192,
    "name": "ml.m5.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 10,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 256,
    "name": "ml.m5.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 11,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 384,
    "name": "ml.m5.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 12,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 8,
    "name": "ml.m5d.large",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 13,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 16,
    "name": "ml.m5d.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 14,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 32,
    "name": "ml.m5d.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 15,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 64,
    "name": "ml.m5d.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 16,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 128,
    "name": "ml.m5d.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 17,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 192,
    "name": "ml.m5d.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 18,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 256,
    "name": "ml.m5d.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 19,
    "_isFastLaunch": false,
    "category": "General purpose",
    "gpuNum": 0,
    "memoryGiB": 384,
    "name": "ml.m5d.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 20,
    "_isFastLaunch": true,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 4,
    "name": "ml.c5.large",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 21,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 8,
    "name": "ml.c5.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 22,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 16,
    "name": "ml.c5.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 23,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 32,
    "name": "ml.c5.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 24,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 72,
    "name": "ml.c5.9xlarge",
    "vcpuNum": 36
   },
   {
    "_defaultOrder": 25,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 96,
    "name": "ml.c5.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 26,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 144,
    "name": "ml.c5.18xlarge",
    "vcpuNum": 72
   },
   {
    "_defaultOrder": 27,
    "_isFastLaunch": false,
    "category": "Compute optimized",
    "gpuNum": 0,
    "memoryGiB": 192,
    "name": "ml.c5.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 28,
    "_isFastLaunch": true,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 16,
    "name": "ml.g4dn.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 29,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 32,
    "name": "ml.g4dn.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 30,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 64,
    "name": "ml.g4dn.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 31,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 128,
    "name": "ml.g4dn.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 32,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 4,
    "memoryGiB": 192,
    "name": "ml.g4dn.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 33,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 256,
    "name": "ml.g4dn.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 34,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 61,
    "name": "ml.p3.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 35,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 4,
    "memoryGiB": 244,
    "name": "ml.p3.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 36,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 8,
    "memoryGiB": 488,
    "name": "ml.p3.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 37,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 8,
    "memoryGiB": 768,
    "name": "ml.p3dn.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 38,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 16,
    "name": "ml.r5.large",
    "vcpuNum": 2
   },
   {
    "_defaultOrder": 39,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 32,
    "name": "ml.r5.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 40,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 64,
    "name": "ml.r5.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 41,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 128,
    "name": "ml.r5.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 42,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 256,
    "name": "ml.r5.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 43,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 384,
    "name": "ml.r5.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 44,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 512,
    "name": "ml.r5.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 45,
    "_isFastLaunch": false,
    "category": "Memory Optimized",
    "gpuNum": 0,
    "memoryGiB": 768,
    "name": "ml.r5.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 46,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 16,
    "name": "ml.g5.xlarge",
    "vcpuNum": 4
   },
   {
    "_defaultOrder": 47,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 32,
    "name": "ml.g5.2xlarge",
    "vcpuNum": 8
   },
   {
    "_defaultOrder": 48,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 64,
    "name": "ml.g5.4xlarge",
    "vcpuNum": 16
   },
   {
    "_defaultOrder": 49,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 128,
    "name": "ml.g5.8xlarge",
    "vcpuNum": 32
   },
   {
    "_defaultOrder": 50,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 1,
    "memoryGiB": 256,
    "name": "ml.g5.16xlarge",
    "vcpuNum": 64
   },
   {
    "_defaultOrder": 51,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 4,
    "memoryGiB": 192,
    "name": "ml.g5.12xlarge",
    "vcpuNum": 48
   },
   {
    "_defaultOrder": 52,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 4,
    "memoryGiB": 384,
    "name": "ml.g5.24xlarge",
    "vcpuNum": 96
   },
   {
    "_defaultOrder": 53,
    "_isFastLaunch": false,
    "category": "Accelerated computing",
    "gpuNum": 8,
    "memoryGiB": 768,
    "name": "ml.g5.48xlarge",
    "vcpuNum": 192
   }
  ],
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}