{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker.session import Session\n",
    "\n",
    "# The session's region selects which regional sample-data bucket is read below.\n",
    "sagemaker_session = sagemaker.Session()\n",
    "region = sagemaker_session.boto_session.region_name\n",
    "\n",
    "# BUCKET receives the transform output; MODEL is passed to the Transformer as model_name.\n",
    "# NOTE(review): both look account-specific -- confirm they exist before running.\n",
    "BUCKET = 'gps-serverless-workshop'\n",
    "MODEL = 'sagemaker-tensorflow-2018-11-24-19-02-41-238'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT = 'ServerlessAIWorkshop/BatchTransform/output'\n",
    "INPUT_BUCKET = 'sagemaker-sample-data-{}'.format(region)\n",
    "DATADIR = 'batch-transform/mnist-1000-samples'\n",
    "\n",
    "# Number of output files (data-0 .. data-9) read back and compared with the labels.\n",
    "NUM_PREDICTIONS = 10\n",
    "\n",
    "input_location = 's3://{}/{}'.format(INPUT_BUCKET, DATADIR)\n",
    "output_location = 's3://{}/{}'.format(BUCKET, OUTPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the transformer object\n",
    "transformer = sagemaker.transformer.Transformer(\n",
    "    base_transform_job_name='Serverless-Workshop',\n",
    "    model_name=MODEL,\n",
    "    instance_count=4,\n",
    "    instance_type='ml.c5.18xlarge',\n",
    "    output_path=output_location,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# To start a transform job:\n",
    "transformer.transform(input_location, content_type='text/csv', split_type='Line')\n",
    "# Then wait until transform job is completed\n",
    "transformer.wait()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3 = boto3.resource('s3')\n",
    "predictions = []\n",
    "for i in range(NUM_PREDICTIONS):\n",
    "    # Each result object is expected under the output prefix as data-<i>.csv.out.\n",
    "    file_key = '{}/data-{}.csv.out'.format(OUTPUT, i)\n",
    "\n",
    "    output_obj = s3.Object(BUCKET, file_key)\n",
    "    output = output_obj.get()['Body'].read().decode('utf-8')\n",
    "\n",
    "    # The response is JSON; the predicted classes sit under outputs.classes.int64Val.\n",
    "    predictions.extend(json.loads(output)['outputs']['classes']['int64Val'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The original labels for these ten samples are:\n",
    "\n",
    "```\n",
    "7, 2, 1, 0, 4, 1, 4, 9, 5, 9\n",
    "```\n",
    "\n",
    "Now let's print out the predictions to compare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# str() keeps the join working whether the values were decoded as strings or as ints.\n",
    "print(', '.join(str(p) for p in predictions))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}