{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Managed Spot Training for XGBoost\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "---" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\n", "This notebook shows usage of SageMaker Managed Spot infrastructure for XGBoost training. Below we show how Spot instances can be used for the 'algorithm mode' and 'script mode' training methods with the XGBoost container. \n", "\n", "[Managed Spot Training](https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html) uses Amazon EC2 Spot instance to run training jobs instead of on-demand instances. You can specify which training jobs use spot instances and a stopping condition that specifies how long Amazon SageMaker waits for a job to run using Amazon EC2 Spot instances." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "This notebook was tested in Amazon SageMaker Studio on a ml.t3.medium instance with Python 3 (Data Science) kernel.\n", "\n", "In this notebook we will perform XGBoost training as described [here](). See the original notebook for more details on the data. " ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Setup variables and define functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade sagemaker" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "import io\n", "import os\n", "import boto3\n", "import sagemaker\n", "\n", "role = sagemaker.get_execution_role()\n", "region = boto3.Session().region_name\n", "\n", "# S3 bucket for saving code and model artifacts.\n", "# Feel free to specify a different bucket here if you wish.\n", "bucket = sagemaker.Session().default_bucket()\n", "prefix = \"sagemaker/DEMO-xgboost-spot\"\n", "# customize to your bucket where you have would like to store the data" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Fetching the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "s3 = boto3.client(\"s3\")\n", "# Load the dataset\n", "FILE_DATA = \"abalone\"\n", "s3.download_file(\n", " f\"sagemaker-example-files-prod-{region}\",\n", " f\"datasets/tabular/uci_abalone/abalone.libsvm\",\n", " FILE_DATA,\n", ")\n", "sagemaker.Session().upload_data(FILE_DATA, bucket=bucket, key_prefix=prefix + \"/train\")\n", "sagemaker.Session().upload_data(FILE_DATA, bucket=bucket, key_prefix=prefix + \"/validation\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Obtaining the latest XGBoost container\n", "We obtain the new container by specifying the framework version (1.7-1). This version specifies the upstream XGBoost framework version (1.7) and an additional SageMaker version (1). 
, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Obtaining the latest XGBoost container\n", "We obtain the new container by specifying the framework version (1.7-1). This version specifies the upstream XGBoost framework version (1.7) and an additional SageMaker version (1). If you have an existing XGBoost workflow based on a previous container (1.0-1, 1.2-2, 1.3-1, or 1.5-1), this is the only change necessary to get the same workflow working with the new container." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "container = sagemaker.image_uris.retrieve(\"xgboost\", region, \"1.7-1\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Training the XGBoost model\n", "\n", "After setting the training parameters, we kick off training and poll for status until training completes, which in this example takes a few minutes.\n", "\n", "To run our training script on SageMaker, we construct a sagemaker.xgboost.estimator.XGBoost estimator, which accepts several constructor arguments:\n", "\n", "* __entry_point__: The path to the Python script SageMaker runs for training and prediction.\n", "* __role__: The IAM role ARN used by the training job.\n", "* __hyperparameters__: A dictionary passed to the train function as hyperparameters.\n", "* __instance_type__ *(optional)*: The type of SageMaker instances for training. __Note__: This particular mode does not currently support training on GPU instance types.\n", "* __sagemaker_session__ *(optional)*: The session used to train on SageMaker." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {\n", " \"max_depth\": \"5\",\n", " \"eta\": \"0.2\",\n", " \"gamma\": \"4\",\n", " \"min_child_weight\": \"6\",\n", " \"subsample\": \"0.7\",\n", " \"objective\": \"reg:squarederror\",\n", " \"num_round\": \"50\",\n", " \"verbosity\": \"2\",\n", "}\n", "\n", "instance_type = \"ml.m5.4xlarge\"\n", "output_path = \"s3://{}/{}/{}/output\".format(bucket, prefix, \"abalone-xgb\")\n", "content_type = \"libsvm\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If Spot instances are used, the training job can be interrupted, causing it to take longer to start or finish. If a training job is interrupted, a checkpointed snapshot can be used to resume from the previously saved point, which can save training time (and cost).\n", "\n", "To enable checkpointing for Managed Spot Training with SageMaker XGBoost, we need to configure three things: \n", "\n", "1. Enable the `use_spot_instances` constructor arg - a simple self-explanatory boolean. \n", "\n", "2. Set the `max_wait` constructor arg - an int representing the amount of time you are willing to wait for Spot infrastructure to become available. Some instance types are harder to get at Spot prices, so you may have to wait longer. You are not charged for time spent waiting for Spot infrastructure to become available; you are only charged for the compute time consumed once Spot instances have been successfully procured. \n", "\n", "3. Set up the `checkpoint_s3_uri` constructor arg - this tells SageMaker the S3 location where checkpoints are saved. While not strictly necessary, checkpointing is highly recommended for Managed Spot Training jobs because Spot instances can be interrupted on short notice, and resuming from the last checkpoint ensures you don't lose the progress made before the interruption.\n", "\n", "Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job using regular (a.k.a. \"On Demand\") infrastructure.\n", "\n", "Note that `max_wait` can be set if and only if `use_spot_instances` is enabled, and it must be greater than or equal to `max_run`. For reference, these arguments map onto fields of the low-level `CreateTrainingJob` API, as sketched below." ] }
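, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The mapping below is illustrative only (the S3 URI is a placeholder); the actual job is launched with the Python SDK in the following cell.\n", "\n", "```python\n", "# Illustrative: the CreateTrainingJob request fields that the SDK estimator arguments translate into.\n", "spot_related_fields = {\n", " \"EnableManagedSpotTraining\": True, # use_spot_instances\n", " \"StoppingCondition\": {\"MaxRuntimeInSeconds\": 3600, \"MaxWaitTimeInSeconds\": 7200}, # max_run / max_wait\n", " \"CheckpointConfig\": {\"S3Uri\": \"s3://<bucket>/<prefix>/checkpoints/\"}, # checkpoint_s3_uri\n", "}\n", "```" ] }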
\"On Demand\") infrastructure.\n", "\n", "Note that `train_max_wait` can be set if and only if `train_use_spot_instances` is enabled and must be greater than or equal to `train_max_run`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "from sagemaker.inputs import TrainingInput\n", "\n", "job_name = \"DEMO-xgboost-spot-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "print(\"Training job\", job_name)\n", "\n", "use_spot_instances = True\n", "max_run = 3600\n", "max_wait = 7200 if use_spot_instances else None\n", "checkpoint_s3_uri = (\n", " \"s3://{}/{}/checkpoints/{}\".format(bucket, prefix, job_name) if use_spot_instances else None\n", ")\n", "print(\"Checkpoint path:\", checkpoint_s3_uri)\n", "\n", "estimator = sagemaker.estimator.Estimator(\n", " container,\n", " role,\n", " hyperparameters=hyperparameters,\n", " instance_count=1,\n", " instance_type=instance_type,\n", " volume_size=5, # 5 GB\n", " output_path=output_path,\n", " sagemaker_session=sagemaker.Session(),\n", " use_spot_instances=use_spot_instances,\n", " max_run=max_run,\n", " max_wait=max_wait,\n", " checkpoint_s3_uri=checkpoint_s3_uri,\n", ")\n", "\n", "train_input = TrainingInput(\n", " s3_data=\"s3://{}/{}/{}\".format(bucket, prefix, \"train\"), content_type=\"libsvm\"\n", ")\n", "estimator.fit({\"train\": train_input}, job_name=job_name)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Savings\n", "Towards the end of the job you should see two lines of output printed:\n", "\n", "- `Training seconds: X` : This is the actual compute-time your training job spent\n", "- `Billable seconds: Y` : This is the time you will be billed for after Spot discounting is applied.\n", "\n", "If you enabled the `train_use_spot_instances`, then you should see a notable difference between `X` and `Y` signifying the cost savings you will get for having chosen Managed Spot Training. This should be reflected in an additional line:\n", "- `Managed Spot Training savings: (1-Y/X)*100 %`" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Train with Automatic Model Tuning ([HPO](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html)) and Spot Training enabled \n", "***\n", "You could also train with Amazon SageMaker Automatic Model Tuning. AMT, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and ranges of hyperparameters that you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by a metric that you choose. 
, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Train with Automatic Model Tuning ([HPO](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html)) and Spot Training enabled \n", "***\n", "You can also train with Amazon SageMaker Automatic Model Tuning. AMT, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and the ranges of hyperparameters that you specify. It then chooses the hyperparameter values that result in the model that performs best, as measured by a metric that you choose. We will use a [HyperparameterTuner](https://sagemaker.readthedocs.io/en/stable/api/training/tuner.html) object to interact with the Amazon SageMaker hyperparameter tuning APIs.\n", " \n", "The code sample below shows you how to use the `HyperparameterTuner` and Spot Training together.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.tuner import ContinuousParameter, HyperparameterTuner, IntegerParameter\n", "\n", "# You can select from the hyperparameters supported by the model and configure ranges of values\n", "# to be searched (https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html).\n", "hyperparameter_ranges = {\n", " \"max_depth\": IntegerParameter(0, 10, scaling_type=\"Auto\"),\n", " \"num_round\": IntegerParameter(1, 4000, scaling_type=\"Auto\"),\n", " \"alpha\": ContinuousParameter(0, 2, scaling_type=\"Auto\"),\n", " \"subsample\": ContinuousParameter(0.5, 1, scaling_type=\"Auto\"),\n", " \"min_child_weight\": ContinuousParameter(0, 120, scaling_type=\"Auto\"),\n", " \"gamma\": ContinuousParameter(0, 5, scaling_type=\"Auto\"),\n", " \"eta\": ContinuousParameter(0.1, 0.5, scaling_type=\"Auto\"),\n", "}\n", "\n", "# Increase max_jobs to run more training jobs, for increased accuracy (and longer tuning time).\n", "max_jobs = 6\n", "# Increase max_parallel_jobs to reduce total tuning time, within your account limits.\n", "# Note: if max_jobs == max_parallel_jobs, Bayesian search degenerates to random search.\n", "max_parallel_jobs = 2\n", "\n", "hp_tuner = HyperparameterTuner(\n", " estimator,\n", " \"validation:rmse\",\n", " hyperparameter_ranges,\n", " max_jobs=max_jobs,\n", " max_parallel_jobs=max_parallel_jobs,\n", " objective_type=\"Minimize\",\n", " base_tuning_job_name=job_name,\n", ")\n", "\n", "# Launch a SageMaker Tuning job to search for the best hyperparameters.\n", "# The tuner requires a `validation` channel to emit the validation:rmse metric.\n", "# Since we only created a `train` channel, we reuse it for validation.\n", "hp_tuner.fit({\"train\": train_input, \"validation\": train_input})" ] }
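, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Once the tuning job finishes, you can inspect which hyperparameter combination performed best. Below is a minimal sketch using the SDK's tuner helpers; it assumes the tuning job above has completed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Name of the training job that achieved the best objective value.\n", "print(\"Best training job:\", hp_tuner.best_training_job())\n", "\n", "# Per-job hyperparameters and objective values as a pandas DataFrame.\n", "results_df = hp_tuner.analytics().dataframe()\n", "results_df.sort_values(\"FinalObjectiveValue\").head()" ] }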
, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Enabling checkpointing for script mode\n", "\n", "An additional mode of operation is to run customizable scripts as part of the training and inference jobs. See [this notebook](./xgboost_abalone_dist_script_mode.ipynb) for details on how to set up script mode. \n", "\n", "Here we highlight the specific changes that enable checkpointing and the use of Spot instances. \n", "\n", "Checkpointing in the framework mode for SageMaker XGBoost can be performed using two convenient functions: \n", "\n", "- `save_checkpoint`: returns a callback function that checkpoints the model after each round. It is passed to XGBoost via the [`callbacks`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train) argument. \n", "\n", "- `load_checkpoint`: loads an existing checkpoint so that training resumes from where it previously stopped. \n", "\n", "Both functions take the checkpoint directory as input, which in the example below is set to `/opt/ml/checkpoints`. \n", "The primary arguments that change for the `xgb.train` call are: \n", "\n", "1. `xgb_model`: the previous checkpoint (saved by an earlier, partially completed run) obtained via `load_checkpoint`. This is `None` if no previous checkpoint is available. \n", "2. `callbacks`: contains the callback function that performs the checkpointing.\n", "\n", "The updated script looks like the following. \n", "\n", "---------\n", "```python\n", "# save_checkpoint and load_checkpoint are utilities shipped with the SageMaker XGBoost container\n", "from sagemaker_xgboost_container.checkpointing import save_checkpoint, load_checkpoint\n", "\n", "CHECKPOINTS_DIR = '/opt/ml/checkpoints' # default location for checkpoints\n", "callbacks = [save_checkpoint(CHECKPOINTS_DIR)]\n", "prev_checkpoint, n_iterations_prev_run = load_checkpoint(CHECKPOINTS_DIR)\n", "bst = xgb.train(\n", " params=train_hp,\n", " dtrain=dtrain,\n", " evals=watchlist,\n", " # train only the rounds remaining after the previous partial run\n", " num_boost_round=(args.num_round - n_iterations_prev_run),\n", " xgb_model=prev_checkpoint,\n", " callbacks=callbacks,\n", ")\n", "```" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Using the SageMaker XGBoost Estimator\n", "\n", "The XGBoost estimator class in the SageMaker Python SDK allows us to run that script as a training job on the Amazon SageMaker managed training infrastructure. We’ll also pass the estimator our IAM role, the type of instance we want to use, and a dictionary of the hyperparameters that we want to pass to our script." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.xgboost.estimator import XGBoost\n", "\n", "job_name = \"DEMO-xgboost-regression-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "print(\"Training job\", job_name)\n", "checkpoint_s3_uri = (\n", " \"s3://{}/{}/checkpoints/{}\".format(bucket, prefix, job_name) if use_spot_instances else None\n", ")\n", "print(\"Checkpoint path:\", checkpoint_s3_uri)\n", "\n", "xgb_script_mode_estimator = XGBoost(\n", " entry_point=\"abalone.py\",\n", " hyperparameters=hyperparameters,\n", " role=role,\n", " instance_count=1,\n", " instance_type=instance_type,\n", " framework_version=\"1.7-1\",\n", " output_path=\"s3://{}/{}/{}/output\".format(bucket, prefix, \"xgboost-script-mode\"),\n", " use_spot_instances=use_spot_instances,\n", " max_run=max_run,\n", " max_wait=max_wait,\n", " checkpoint_s3_uri=checkpoint_s3_uri,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Training is as simple as calling `fit` on the estimator. This starts a SageMaker training job that downloads the data, invokes the entry point code (in the provided script file), and saves any model artifacts that the script creates. In this case, the script requires a `train` and a `validation` channel. Since we only created a `train` channel, we reuse it for validation. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xgb_script_mode_estimator.fit({\"train\": train_input, \"validation\": train_input}, job_name=job_name)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As previously stated, the estimator can also be passed to a [HyperparameterTuner](https://sagemaker.readthedocs.io/en/stable/api/training/tuner.html) object to interact with the Amazon SageMaker hyperparameter tuning APIs and create a hyperparameter tuning job. Hyperparameters are tuned automatically, which in most cases results in a more accurate model."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hp_tuner = HyperparameterTuner(\n", " xgb_script_mode_estimator,\n", " \"validation:rmse\",\n", " hyperparameter_ranges,\n", " max_jobs=max_jobs,\n", " max_parallel_jobs=max_parallel_jobs,\n", " objective_type=\"Minimize\",\n", " base_tuning_job_name=job_name,\n", ")\n", "\n", "# Launch a SageMaker Tuning job to search for the best hyperparameters\n", "# In this case, the tuner requires a `validation` channel to emit the validation:rmse metric.\n", "# Since we only created a `train` channel, we re-use it for validation.\n", "hp_tuner.fit({\"train\": train_input, \"validation\": train_input})" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|xgboost_abalone|xgboost_managed_spot_training.ipynb)\n" ] } ], "metadata": { "anaconda-cloud": {}, "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 3.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }