{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Solution b.\n", "\n", "Create a inference script. Let's call it `inference.py`.\n", "\n", "Let's also create the `input_fn`, `predict_fn`, `output_fn` and `model_fn` functions.\n", "\n", "Copy the cells below and paste in [the main notebook](../deployment_hosting.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile inference.py\n", "\n", "import os\n", "import pickle\n", "\n", "import xgboost\n", "import sagemaker_xgboost_container.encoder as xgb_encoders\n", "\n", "# Same as in the training script\n", "def model_fn(model_dir):\n", " \"\"\"Load a model. For XGBoost Framework, a default function to load a model is not provided.\n", " Users should provide customized model_fn() in script.\n", " Args:\n", " model_dir: a directory where model is saved.\n", " Returns:\n", " A XGBoost model.\n", " XGBoost model format type.\n", " \"\"\"\n", " model_files = (file for file in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, file)))\n", " model_file = next(model_files)\n", " try:\n", " booster = pickle.load(open(os.path.join(model_dir, model_file), 'rb'))\n", " format = 'pkl_format'\n", " except Exception as exp_pkl:\n", " try:\n", " booster = xgboost.Booster()\n", " booster.load_model(os.path.join(model_dir, model_file))\n", " format = 'xgb_format'\n", " except Exception as exp_xgb:\n", " raise ModelLoadInferenceError(\"Unable to load model: {} {}\".format(str(exp_pkl), str(exp_xgb)))\n", " booster.set_param('nthread', 1)\n", " return booster\n", "\n", "\n", "def input_fn(request_body, request_content_type):\n", " \"\"\"\n", " The SageMaker XGBoost model server receives the request data body and the content type,\n", " and invokes the `input_fn`.\n", " The input_fn that just validates request_content_type and prints\n", " \"\"\"\n", " \n", " print(\"Hello from the PRE-processing function!!!\")\n", " \n", " if request_content_type == \"text/csv\":\n", " return xgb_encoders.csv_to_dmatrix(request_body)\n", " else:\n", " raise ValueError(\n", " \"Content type {} is not supported.\".format(request_content_type)\n", " )\n", "\n", "def predict_fn(input_object, model):\n", " \"\"\"\n", " SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.\n", " \"\"\"\n", " return model.predict(input_object)[0]\n", "\n", "\n", "def output_fn(prediction, response_content_type):\n", " \"\"\"\n", " After invoking predict_fn, the model server invokes `output_fn`.\n", " An output_fn that just adds a column to the output and validates response_content_type\n", " \"\"\"\n", " print(\"Hello from the POST-processing function!!!\")\n", " \n", " appended_output = \"hello from post-processing function!!!\"\n", " predictions = [prediction, appended_output]\n", "\n", " if response_content_type == \"text/csv\":\n", " return ','.join(str(x) for x in predictions)\n", " else:\n", " raise ValueError(\"Content type {} is not supported.\".format(response_content_type))\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deploy the new model with the inference script:\n", "\n", "- find the S3 bucket where the artifact is stored (you can create a tarball and upload it to S3 or use another model that was previously created in SageMaker)\n", "\n", "#### Finding a previously trained model:\n", "\n", "Go to the Experiments tab in Studio again: \n", "![experiments_s3_artifact.png](../media/experiments_s3_artifact.png)\n", "\n", "Choose another 
, { "cell_type": "markdown", "metadata": {}, "source": [ "Deploy the new model with the inference script:\n", "\n", "- find the S3 URI where the model artifact is stored (you can create a tarball yourself and upload it to S3, or reuse a model that was previously trained in SageMaker)\n", "\n", "#### Finding a previously trained model:\n", "\n", "Go to the Experiments tab in Studio again:\n", "![experiments_s3_artifact.png](../media/experiments_s3_artifact.png)\n", "\n", "Choose another trained model, such as the one trained in Framework mode (right-click and choose `Open in trial details`):\n", "\n", "![trial_s3_artifact.png](../media/trial_s3_artifact.png)\n", "\n", "Click on `Artifacts` and look at the `Output artifacts`:\n", "![trial_uri_s3_artifact.png](../media/trial_uri_s3_artifact.png)\n", "\n", "Copy the S3 URI of the `SageMaker.ModelArtifact`, which points to where the model is saved.\n", "\n", "In this example:\n", "```\n", "s3_artifact=\"s3://sagemaker-studio-us-east-2-/xgboost-churn/output/demo-xgboost-customer-churn-2021-04-13-18-51-56-144/output/model.tar.gz\"\n", "```" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Replace with the S3 URI of your model artifact\n", "s3_artifact = \"s3:///PATH/TO/model.tar.gz\"" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%store -r docker_image_name\n", "%store -r framework_version" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Deploy it:**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.xgboost.model import XGBoostModel\n", "\n", "# `role` is the SageMaker execution role defined in the main notebook\n", "xgb_inference_model = XGBoostModel(\n", "    entry_point=\"inference.py\",\n", "    model_data=s3_artifact,\n", "    role=role,\n", "    image_uri=docker_image_name,\n", "    framework_version=framework_version,\n", "    py_version=\"py3\",\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import gmtime, strftime\n", "\n", "# `prefix` is defined in the main notebook\n", "data_capture_prefix = '{}/datacapture'.format(prefix)\n", "\n", "endpoint_name = \"model-xgboost-customer-churn-\" + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n", "print(\"EndpointName = {}\".format(endpoint_name))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.model_monitor import DataCaptureConfig\n", "\n", "# `bucket` is defined in the main notebook\n", "predictor = xgb_inference_model.deploy(\n", "    initial_instance_count=1,\n", "    instance_type='ml.m4.xlarge',\n", "    endpoint_name=endpoint_name,\n", "    data_capture_config=DataCaptureConfig(\n", "        enable_capture=True,\n", "        sampling_percentage=100,\n", "        destination_s3_uri='s3://{}/{}'.format(bucket, data_capture_prefix)\n", "    )\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Updating an existing endpoint instead of creating a new one\n", "# model_name = xgb_inference_model.name\n", "\n", "# from sagemaker.predictor import Predictor\n", "# predictor = Predictor(endpoint_name=endpoint_name)\n", "# predictor.update_endpoint(instance_type='ml.m4.xlarge',\n", "#                           initial_instance_count=1,\n", "#                           model_name=model_name,\n", "#                           data_capture_config=DataCaptureConfig(\n", "#                               enable_capture=True,\n", "#                               sampling_percentage=100,\n", "#                               destination_s3_uri='s3://{}/{}'.format(bucket, data_capture_prefix)\n", "#                           )\n", "#                           )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Send some requests:**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "\n", "runtime_client = boto3.client(\"sagemaker-runtime\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open('/root/amazon-sagemaker-workshop/4-Deployment/RealTime/config/test_sample.csv', 'r') as f:\n", "    for row in f:\n", "        payload = row.rstrip('\\n')\n", "        print(f\"Sending: {payload}\")\n", "        response = runtime_client.invoke_endpoint(\n", "            EndpointName=endpoint_name,\n", "            ContentType='text/csv',\n", "            Accept='text/csv',\n", "            Body=payload,\n", "        )\n", "\n", "        print(f\"\\nReceived: {response['Body'].read().decode('utf-8')}\")\n", "        break  # send only the first row" ] }
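, { "cell_type": "markdown", "metadata": {}, "source": [ "Since data capture is enabled, request/response pairs are written to S3 under the `datacapture` prefix. The cell below is a minimal sketch that lists the captured files; capture files can take a few minutes to appear, and `bucket` is assumed to be defined in the main notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: list the files written by data capture (they may take a few\n", "# minutes to appear after the first requests). Assumes `bucket` is defined\n", "# in the main notebook.\n", "import boto3\n", "\n", "s3_client = boto3.client(\"s3\")\n", "\n", "result = s3_client.list_objects_v2(Bucket=bucket, Prefix=data_capture_prefix)\n", "for obj in result.get(\"Contents\", []):\n", "    print(obj[\"Key\"])" ] }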
, { "cell_type": "markdown", "metadata": {}, "source": [ "Go to CloudWatch Logs and check the inference logic: the messages printed by `input_fn` and `output_fn` appear in the endpoint's log stream.\n", "\n", "[Link to CloudWatch Logs](https://us-east-2.console.aws.amazon.com/cloudwatch/home?region=us-east-2#logsV2:log-groups$3FlogGroupNameFilter$3D$252Faws$252Fsagemaker$252FEndpoints$252F)" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-2:429704687514:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }