{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import json\n",
    "\n",
    "from sagemaker.session import Session\n",
    "sagemaker_session = sagemaker.Session()\n",
    "region = sagemaker_session.boto_session.region_name\n",
    "\n",
    "BUCKET='gps-serverless-workshop'\n",
    "MODEL='sagemaker-tensorflow-2018-11-24-19-02-41-238'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT = 'ServerlessAIWorkshop/BatchTransform/output'\n",
    "INPUT_BUCKET = 'sagemaker-sample-data-{}'.format(region)\n",
    "DATADIR = 'batch-transform/mnist-1000-samples'\n",
    "\n",
    "input_key = 'kmeans_batch_example/input/valid-data.csv'\n",
    "input_location = 's3://{}/{}'.format(INPUT_BUCKET, DATADIR)\n",
    "output_location = 's3://{}/{}'.format(BUCKET, OUTPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the transformer object\n",
    "transformer =sagemaker.transformer.Transformer(\n",
    "    base_transform_job_name='Serverless-Workshop',\n",
    "    model_name=MODEL,\n",
    "    instance_count=4,\n",
    "    instance_type='ml.c5.18xlarge',\n",
    "    output_path=output_location\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# To start a transform job:\n",
    "transformer.transform(input_location, content_type='text/csv', split_type='Line')\n",
    "# Then wait until transform job is completed\n",
    "transformer.wait()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3 = boto3.resource('s3')\n",
    "predictions = []\n",
    "for i in range(10):\n",
    "    file_key = '{}/data-{}.csv.out'.format(OUTPUT, i)\n",
    "\n",
    "    output_obj = s3.Object(BUCKET, file_key)\n",
    "    output = output_obj.get()[\"Body\"].read().decode('utf-8')\n",
    "\n",
    "    predictions.extend(json.loads(output)['outputs']['classes']['int64Val'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we can see the original labels are:\n",
    "\n",
    "```\n",
    "7, 2, 1, 0, 4, 1, 4, 9, 5, 9\n",
    "```\n",
    "\n",
    "Now let's print out the predictions to compare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(', '.join(predictions))"
   ]
  },
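  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the sketch below compares the predictions against the ten labels listed above and reports how many match. The `expected` list is simply transcribed from that list and assumes the first ten output files correspond to those samples, in order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sanity check (assumption: the first ten .out files line up, in order,\n",
    "# with the ten labels listed above). int64Val values come back as strings,\n",
    "# so we compare string to string.\n",
    "expected = ['7', '2', '1', '0', '4', '1', '4', '9', '5', '9']\n",
    "\n",
    "matches = sum(p == e for p, e in zip(predictions, expected))\n",
    "print('{} of {} predictions match the original labels'.format(matches, len(expected)))"
   ]
  }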
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}