{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiclass classification on Spark with Amazon SageMaker XGBoost algorithm\n", "_**Single machine and distributed training on Spark for multiclass classification with Amazon SageMaker XGBoost algorithm**_\n", "\n", "---\n", "\n", "## Introduction\n", "\n", "\n", "This notebook demonstrates the use of Amazon SageMaker’s implementation of the XGBoost algorithm to train and host a multiclass classification model using the sagemaker-spark SDK.\n", "\n", "---\n", "\n", "## Download Dataset\n", "\n", "For this example we load the MNIST dataset, already converted to libsvm format, from a SageMaker sample-data bucket on S3." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "region = \"us-east-1\"\n", "\n", "# Load MNIST in libsvm format; each 28x28 image has 784 pixel features\n", "training_data = spark.read.format(\"libsvm\").option(\"numFeatures\", \"784\").load(\"s3a://sagemaker-sample-data-{}/spark/mnist/train/\".format(region))\n", "\n", "test_data = spark.read.format(\"libsvm\").option(\"numFeatures\", \"784\").load(\"s3a://sagemaker-sample-data-{}/spark/mnist/test/\".format(region))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create And Invoke Model\n", "\n", "The IAM role specified in `iam_role` is passed to the containers SageMaker uses for model hosting, allowing them to do things like publish CloudWatch metrics and download data from S3. 
If you are unsure which policies to attach to this role, try attaching the managed `AmazonSageMakerFullAccess` policy and scoping down permissions from there if needed.\n", "\n", "The following cell takes ~10-20 minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker_pyspark import IAMRole\n", "from sagemaker_pyspark.algorithms import XGBoostSageMakerEstimator\n", "\n", "iam_role = \"name_of_your_iam_role\"\n", "\n", "xgboost_estimator = XGBoostSageMakerEstimator(\n", " trainingInstanceType=\"ml.m4.xlarge\",\n", " trainingInstanceCount=1,\n", " endpointInstanceType=\"ml.m4.xlarge\",\n", " endpointInitialInstanceCount=1,\n", " sagemakerRole=IAMRole(iam_role))\n", "\n", "xgboost_estimator.setNumRound(25) # Set number of boosting rounds\n", "xgboost_estimator.setNumClasses(10) # MNIST contains digits 0-9\n", "xgboost_estimator.setObjective('multi:softmax') # Set XGBoost objective to multi-class classification w/ SoftMax\n", "\n", "xgboost_model = xgboost_estimator.fit(training_data)\n", "\n", "transformed_data = xgboost_model.transform(test_data.limit(5)) # Score first 5 rows of test data\n", "transformed_data.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create And Invoke Model From An Existing Endpoint\n", "\n", "In the previous step we created and trained a model, then invoked it through the model object. Here we create the model object from an existing SageMaker endpoint and use it to score the same test data. 
" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+\n", "|label| features|prediction|\n", "+-----+--------------------+----------+\n", "| 7.0|(784,[202,203,204...| 7.0|\n", "| 2.0|(784,[94,95,96,97...| 2.0|\n", "| 1.0|(784,[128,129,130...| 1.0|\n", "| 0.0|(784,[124,125,126...| 0.0|\n", "| 4.0|(784,[150,151,159...| 4.0|\n", "+-----+--------------------+----------+" ] } ], "source": [ "from sagemaker_pyspark import SageMakerModel, EndpointCreationPolicy\n", "from sagemaker_pyspark.transformation.serializers import LibSVMRequestRowSerializer\n", "from sagemaker_pyspark.transformation.deserializers import XGBoostCSVRowDeserializer\n", "\n", "my_endpoint = xgboost_model.endpointName # Get endpoint name of model created in previous step\n", "\n", "xgboost_model = SageMakerModel(\n", " endpointInstanceType=None,\n", " endpointInitialInstanceCount=None,\n", " requestRowSerializer=LibSVMRequestRowSerializer(),\n", " responseRowDeserializer=XGBoostCSVRowDeserializer(),\n", " existingEndpointName=my_endpoint,\n", " endpointCreationPolicy=EndpointCreationPolicy.DO_NOT_CREATE\n", ")\n", "\n", "transformed_data = xgboost_model.transform(test_data.limit(5))\n", "transformed_data.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Sparkmagic (PySpark)", "language": "", "name": "pysparkkernel" }, "language_info": { "codemirror_mode": { "name": "python", "version": 2 }, "mimetype": "text/x-python", "name": "pyspark", "pygments_lexer": "python2" } }, "nbformat": 4, "nbformat_minor": 2 }