{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "source": [
    "# Spleen 3D segmentation with MONAI\n",
    "\n",
    "This tutorial shows how to run SageMaker managed training with MONAI for 3D segmentation, followed by SageMaker managed inference on the trained model.\n",
    "\n",
    "**Note**: select the *conda_pytorch_latest_p36* kernel.\n",
    "\n",
    "This notebook and the `train.py` script in the `code` folder were derived from the [spleen_segmentation_3d notebook](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).\n",
    "\n",
    "Key features demonstrated here:\n",
    "1. SageMaker managed training with S3 integration\n",
    "2. SageMaker hosted inference\n",
    "\n",
    "The Spleen dataset can be downloaded from https://registry.opendata.aws/msd/.\n",
    "\n",
    "![spleen](http://medicaldecathlon.com/img/spleen0.png)\n",
    "\n",
    "Target: Spleen  \n",
    "Modality: CT  \n",
    "Size: 61 3D volumes (31 Training + 9 Validation + 1 Testing with label and 20 Testing without label)  \n",
    "Source: Memorial Sloan Kettering Cancer Center  \n",
    "Challenge: Large variation in foreground size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install and import MONAI libraries "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install  \"monai[all]==0.8.0\"\n",
    "!python -c \"import monai\" || pip install -q \"monai-weekly[gdown, nibabel, tqdm, ignite]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from monai.utils import first, set_determinism\n",
    "from monai.transforms import (\n",
    "    AsDiscrete,\n",
    "    AsDiscreted,\n",
    "    EnsureChannelFirstd,\n",
    "    Compose,\n",
    "    CropForegroundd,\n",
    "    LoadImage,\n",
    "    LoadImaged,\n",
    "    Orientationd,\n",
    "    RandCropByPosNegLabeld,\n",
    "    ScaleIntensityRanged,\n",
    "    Spacingd,\n",
    "    EnsureTyped,\n",
    "    EnsureType,\n",
    "    Invertd\n",
    ")\n",
    "from monai.handlers.utils import from_engine\n",
    "from monai.networks.nets import UNet\n",
    "from monai.networks.layers import Norm\n",
    "from monai.metrics import DiceMetric\n",
    "from monai.losses import DiceLoss\n",
    "from monai.inferers import sliding_window_inference\n",
    "from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch\n",
    "from monai.config import print_config\n",
    "from monai.apps import download_and_extract\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import tempfile\n",
    "import shutil\n",
    "import os\n",
    "import glob\n",
    "import math\n",
    "import ast\n",
    "from pathlib import Path\n",
    "import boto3\n",
    "import sagemaker \n",
    "from sagemaker import get_execution_role\n",
    "from sagemaker.pytorch import PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the SageMaker execution role, session, region, and default S3 bucket\n",
    "role = get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "region = sess.boto_session.region_name\n",
    "bucket = sess.default_bucket()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the dataset: Spleen dataset\n",
    "+ Download the Spleen dataset if it is not available locally\n",
    "+ Transform the images using Compose from MONAI\n",
    "+ Split the images into training and testing datasets\n",
    "+ Visualize the image and label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download images from public bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resource = \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar\"\n",
    "md5 = \"410d4a301da4e5b2f6f86ec3ddba524e\"\n",
    "compressed_file = \"./Task09_Spleen.tar\"\n",
    "\n",
    "data_dir = \"Spleen3D\" \n",
    "\n",
    "if not os.path.exists(data_dir):\n",
    "    download_and_extract(resource, compressed_file, f\"{data_dir}/datasets\", md5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image transforms\n",
    "\n",
    "* LoadImaged loads the spleen CT images and labels from NIfTI files.\n",
    "* EnsureChannelFirstd adjusts or adds the channel dimension so that the data has a channel-first shape.\n",
    "* Spacingd resamples the data to a voxel spacing of pixdim=(1.5, 1.5, 2.) based on the affine matrix.\n",
    "* Orientationd unifies the data orientation based on the affine matrix.\n",
    "* ScaleIntensityRanged clips intensities to the range [-57, 164] and scales them to [0, 1].\n",
    "* CropForegroundd removes all-zero borders to focus on the valid body area of the images and labels.\n",
    "* EnsureTyped converts the NumPy arrays to PyTorch tensors for the following steps.\n",
    "\n",
    "Applied to the training dataset only:\n",
    "* RandCropByPosNegLabeld randomly crops patch samples from the full image based on the pos / neg ratio. The centers of negative samples must lie in the valid body area."
   ]
  },
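  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the intensity step above concrete, here is a minimal NumPy sketch of the arithmetic behind ScaleIntensityRanged with the parameters used in this notebook: values are mapped linearly from [-57, 164] to [0, 1] and anything outside is clipped. The helper name `scale_intensity_range` and the sample values are illustrative assumptions, not MONAI's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def scale_intensity_range(img, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0):\n",
    "    # Hypothetical helper: linear map [a_min, a_max] -> [b_min, b_max], then clip.\n",
    "    scaled = (img - a_min) / (a_max - a_min) * (b_max - b_min) + b_min\n",
    "    return np.clip(scaled, b_min, b_max)\n",
    "\n",
    "\n",
    "# Made-up CT intensities: far below the window, window minimum,\n",
    "# window midpoint, window maximum, far above the window.\n",
    "hu = np.array([-1000.0, -57.0, 53.5, 164.0, 400.0])\n",
    "scaled_hu = scale_intensity_range(hu)\n",
    "print(scaled_hu)  # values: 0, 0, 0.5, 1, 1\n",
    "```\n",
    "\n",
    "Clipping is why everything outside the chosen window collapses to 0 or 1, which leaves the full [0, 1] range for the selected intensities."
   ]
  },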
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## transform the images through Compose\n",
    "val_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"image\", \"label\"]),  ## keys include image and label with image first\n",
    "        EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n",
    "        Spacingd(keys=[\"image\", \"label\"], pixdim=(\n",
    "            1.5, 1.5, 2.0), mode=(\"bilinear\", \"nearest\")),\n",
    "        Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n",
    "        ScaleIntensityRanged(\n",
    "            keys=[\"image\"], a_min=-57, a_max=164,\n",
    "            b_min=0.0, b_max=1.0, clip=True,\n",
    "        ),\n",
    "        CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n",
    "        EnsureTyped(keys=[\"image\", \"label\"]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Divide the images into training and testing datasets\n",
    "Use 40 labeled volumes for training and hold out 1 for inference and visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_images = sorted(\n",
    "    glob.glob(os.path.join(data_dir, \"datasets/Task09_Spleen/imagesTr\", \"*.nii.gz\")))\n",
    "train_labels = sorted(\n",
    "    glob.glob(os.path.join(data_dir, \"datasets/Task09_Spleen/labelsTr\", \"*.nii.gz\")))\n",
    "data_dicts = [\n",
    "    {\"image\": image_name, \"label\": label_name}\n",
    "    for image_name, label_name in zip(train_images, train_labels)\n",
    "]\n",
    "train_files, test_demo_files = data_dicts[:-1], data_dicts[-1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_demo_files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the image and label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "check_ds = Dataset(data=test_demo_files, transform=val_transforms)\n",
    "check_loader = DataLoader(check_ds, batch_size=1)\n",
    "check_data = first(check_loader)\n",
    "\n",
    "image, label = (check_data[\"image\"][0][0], check_data[\"label\"][0][0])\n",
    "print(f\"image shape: {image.shape}, label shape: {label.shape}\")\n",
    "# plot only the slice [:, :, 80]\n",
    "plt.figure(\"check\", (12, 6))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.title(\"image\")\n",
    "plt.imshow(image[:, :, 80], cmap=\"gray\")\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.title(\"label\")\n",
    "plt.imshow(label[:, :, 80])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model training \n",
    "\n",
    "+ Store the training and testing datasets in separate folders\n",
    "+ Upload the datasets to S3\n",
    "+ Run the SageMaker training job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix=\"MONAI_Segmentation\"\n",
    "\n",
    "processed_train_path = os.path.join(data_dir,\"processed\",\"train\")\n",
    "processed_test_path = os.path.join(data_dir,\"processed\",\"test\")\n",
    "\n",
    "processed_train_images_path = os.path.join(processed_train_path, \"imagesTr\")\n",
    "processed_train_labels_path = os.path.join(processed_train_path, \"labelsTr\")\n",
    "\n",
    "processed_test_images_path = os.path.join(processed_test_path, \"imagesTr\")\n",
    "processed_test_labels_path = os.path.join(processed_test_path, \"labelsTr\")\n",
    "\n",
    "Path(processed_train_images_path).mkdir(parents=True, exist_ok=True)\n",
    "Path(processed_train_labels_path).mkdir(parents=True, exist_ok=True)\n",
    "print(\"Directory '%s' created\" %processed_train_path)\n",
    "\n",
    "Path(processed_test_images_path).mkdir(parents=True, exist_ok=True)\n",
    "Path(processed_test_labels_path).mkdir(parents=True, exist_ok=True)\n",
    "print(\"Directory '%s' created\" %processed_test_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## copy dataset for training \n",
    "for file in train_files:\n",
    "    images = file[\"image\"]\n",
    "    images_dest = processed_train_images_path\n",
    "    label = file[\"label\"]\n",
    "    label_dest = processed_train_labels_path\n",
    "    shutil.copy(images,images_dest)\n",
    "    shutil.copy(label,label_dest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## copy dataset for testing \n",
    "for file in test_demo_files:\n",
    "    images = file[\"image\"]\n",
    "    images_dest = processed_test_images_path\n",
    "    label = file[\"label\"]\n",
    "    label_dest = processed_test_labels_path\n",
    "    shutil.copy(images,images_dest)\n",
    "    shutil.copy(label,label_dest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload datasets to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## upload training dataset to S3\n",
    "S3_inputs = sess.upload_data(\n",
    "    path=processed_train_path,\n",
    "    key_prefix=f\"{prefix}/train\",\n",
    "    bucket=bucket \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## upload testing dataset to S3\n",
    "S3_demo_test = sess.upload_data(\n",
    "    path=processed_test_images_path,\n",
    "    key_prefix=f\"{prefix}/test\",\n",
    "    bucket=bucket \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SageMaker training job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics=[\n",
    "   {\"Name\": \"train:average epoch loss\", \"Regex\": \"average loss: ([0-9\\\\.]*)\"},\n",
    "   {\"Name\": \"train:current mean dice\", \"Regex\": \"current mean dice: ([0-9\\\\.]*)\"},\n",
    "   {\"Name\": \"train:best mean dice\", \"Regex\": \"best mean dice: ([0-9\\\\.]*)\"}\n",
    "]\n",
    "\n",
    "estimator = PyTorch(source_dir=\"code\",\n",
    "                    entry_point=\"train.py\",\n",
    "                    role=role,\n",
    "                    framework_version=\"1.6.0\",\n",
    "                    py_version=\"py3\",\n",
    "                    instance_count=1,\n",
    "#                     instance_type=\"ml.p2.xlarge\",\n",
    "                    instance_type=\"ml.g4dn.2xlarge\",\n",
    "                    hyperparameters={\n",
    "                       \"seed\": 123,\n",
    "                       \"lr\": 0.001,\n",
    "                       \"epochs\": 20\n",
    "                    },\n",
    "                    metric_definitions=metrics,\n",
    "#                     ### spot instance training ###\n",
    "#                    use_spot_instances=True,\n",
    "#                     max_run=2400,\n",
    "#                     max_wait=2400\n",
    "                )\n",
    "\n",
    "\n",
    "estimator.fit(S3_inputs)"
   ]
  },
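  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metric definitions above tell SageMaker which numbers to pull out of the training logs: each regex is searched in the log output and its captured group becomes the metric value. A minimal sketch of that matching with Python's `re` module follows. The two log lines are hypothetical, since the actual format depends on what `train.py` prints.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Same patterns as in the metric definitions of the training cell.\n",
    "patterns = {\n",
    "    \"train:average epoch loss\": r\"average loss: ([0-9\\.]*)\",\n",
    "    \"train:current mean dice\": r\"current mean dice: ([0-9\\.]*)\",\n",
    "    \"train:best mean dice\": r\"best mean dice: ([0-9\\.]*)\",\n",
    "}\n",
    "\n",
    "# Hypothetical log lines, for illustration only.\n",
    "log_lines = [\n",
    "    \"epoch 1 average loss: 0.6123\",\n",
    "    \"current epoch: 2 current mean dice: 0.0412 best mean dice: 0.0412 at epoch 2\",\n",
    "]\n",
    "\n",
    "extracted = []\n",
    "for line in log_lines:\n",
    "    for name, pattern in patterns.items():\n",
    "        match = re.search(pattern, line)\n",
    "        if match:\n",
    "            extracted.append((name, float(match.group(1))))\n",
    "\n",
    "print(extracted)\n",
    "```\n",
    "\n",
    "If a metric does not show up for a training job, checking its regex against a real log line in this way is a quick first test."
   ]
  },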
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference \n",
    "\n",
    "+ Deploy the trained estimator (model artifact in S3) with a customized inference script.\n",
    "+ Run inference on the testing image in S3.\n",
    "+ Visualize the results.\n",
    "\n",
    "The endpoint returns two types of output. If an integer slice number is provided, it returns the inference result for that slice. If a start slice and an end slice are provided, or if the input is \"all\" (meaning all slices), it returns the S3 location where the inference result is saved.\n",
    "\n",
    "Demonstrated in this notebook:\n",
    "1. Inference for multiple slices by looping the endpoint API calls\n",
    "2. Inference across multiple images and slices using loops\n",
    "3. Perform inference on a selection of slices\n",
    "4. Perform inference on all slices"
   ]
  },
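  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three request styles described above can be pictured as one small dispatch on the `nslice` value. The sketch below is a hypothetical helper, not code taken from `inference.py`; it assumes the range is passed as a `\"start:end\"` string (the form used for the slice selection later in this notebook) and that the end index is exclusive, as in Python's `range()`.\n",
    "\n",
    "```python\n",
    "def resolve_slices(nslice, total_slices):\n",
    "    # Hypothetical helper: turn an \"nslice\" payload value into slice indices.\n",
    "    if isinstance(nslice, int):\n",
    "        # A single integer selects exactly one slice.\n",
    "        return [nslice]\n",
    "    if nslice == \"all\":\n",
    "        # \"all\" selects every slice of the volume.\n",
    "        return list(range(total_slices))\n",
    "    # Otherwise expect \"start:end\"; the end is assumed to be exclusive.\n",
    "    start, end = (int(part) for part in nslice.split(\":\"))\n",
    "    return list(range(start, end))\n",
    "\n",
    "\n",
    "print(resolve_slices(3, 100))        # one slice\n",
    "print(resolve_slices(\"70:75\", 100))  # a selection of slices\n",
    "print(len(resolve_slices(\"all\", 100)))\n",
    "```\n",
    "\n",
    "Only the first case maps to a result that is returned directly; for the other two the endpoint is described as returning an S3 location instead."
   ]
  },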
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create endpoint\n",
    "\n",
    "Challenge 1:\n",
    "+ Can you host the model on a selected instance type, e.g. \"ml.m5.4xlarge\"?\n",
    "+ Can you set the serializer and deserializer to JSON?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## realtime endpoint\n",
    "\n",
    "predictor = estimator.deploy(\n",
    "    initial_instance_count=1,\n",
    "    source_dir=\"code\",\n",
    "    entry_point=\"inference.py\", \n",
    "    instance_type=<to do>,\n",
    "    serializer=<to do>,\n",
    "    deserializer=<to do>\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference for multiple slices by looping the endpoint API calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "test_demo_preds=[]\n",
    "\n",
    "totalslice = np.array(image).shape[-1]\n",
    "nsliceend=10  # make sure nsliceend <= totalslice\n",
    "nslicestart = 0\n",
    "prefix_key = f\"{prefix}/test\"\n",
    "file = test_demo_files[0][\"image\"].split(\"/\")[-1]\n",
    "\n",
    "### Option 1 - loop over all slices using totalslice\n",
    "# for counter in range(totalslice):\n",
    "\n",
    "### Option 2 - loop over the slices from nslicestart to nsliceend\n",
    "for counter in range(int(nslicestart),int(nsliceend)):\n",
    "    payload={\n",
    "        \"bucket\": bucket,\n",
    "        \"key\": prefix_key,\n",
    "        \"file\": file,\n",
    "        \"nslice\": counter\n",
    "            }\n",
    "    response_pred=predictor.predict(payload)\n",
    "    print(\"inference for slice\",counter)\n",
    "    test_demo_preds.append(response_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_demo_ds = check_ds\n",
    "test_demo_loader = check_loader\n",
    "test_demo_data = check_data"
   ]
  },
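  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nslice=1\n",
    "# sys.getsizeof only measures the Python wrapper object, so compute the tensor data size in bytes instead\n",
    "pred_tensor = torch.tensor(test_demo_preds[nslice][\"pred\"])\n",
    "pred_tensor.element_size() * pred_tensor.nelement()"
   ]
  },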
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the result for 1 slice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image, label = (test_demo_data[\"image\"][0][0], test_demo_data[\"label\"][0][0])\n",
    "print(f\"image shape: {image.shape}, label shape: {label.shape}\")\n",
    "\n",
    "# Visualization\n",
    "# plot the slice [:, :, nslice]\n",
    "plt.figure(\"check\", (18, 6))\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.title(\"image\")\n",
    "plt.imshow(test_demo_data[\"image\"][0, 0, :, :, nslicestart+nslice], cmap=\"gray\")\n",
    "plt.subplot(1, 3, 2)\n",
    "plt.title(\"label\")\n",
    "plt.imshow(test_demo_data[\"label\"][0, 0, :, :, nslicestart+nslice])\n",
    "plt.subplot(1, 3, 3)\n",
    "plt.title(\"output\")\n",
    "plt.imshow(test_demo_preds[nslice][\"pred\"])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Challenge 2: perform inference on a selection of slices and visualize them\n",
    "For inference across multiple image slices, the endpoint writes the output file to S3 and returns its S3 URI."
   ]
  },
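  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An S3 URI has to be separated into a bucket and a key before the object can be downloaded. A minimal sketch using only the standard library follows. The helper `split_s3_uri` and the example URI are hypothetical, and the exact shape of the endpoint response is not shown in this notebook, so the URI may need to be picked out of the response first.\n",
    "\n",
    "```python\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "\n",
    "def split_s3_uri(s3_uri):\n",
    "    # Hypothetical helper: split \"s3://bucket/key\" into (bucket, key).\n",
    "    parsed = urlparse(s3_uri)\n",
    "    if parsed.scheme != \"s3\" or not parsed.netloc:\n",
    "        raise ValueError(f\"not an S3 URI: {s3_uri}\")\n",
    "    return parsed.netloc, parsed.path.lstrip(\"/\")\n",
    "\n",
    "\n",
    "# Placeholder URI, for illustration only.\n",
    "result_bucket, result_key = split_s3_uri(\n",
    "    \"s3://example-bucket/MONAI_Segmentation/output/pred.json\"\n",
    ")\n",
    "print(result_bucket, result_key)\n",
    "```\n",
    "\n",
    "The resulting pair can then be passed to `boto3.client(\"s3\").download_file(result_bucket, result_key, local_path)`, where `local_path` is a local file name of your choice."
   ]
  },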
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "slicestart=70\n",
    "sliceend=75\n",
    "sliceselect = f\"{slicestart}:{sliceend}\"\n",
    "\n",
    "payload_multi={\n",
    "    \"bucket\": bucket,\n",
    "    \"key\": prefix_key,\n",
    "    \"file\": file,\n",
    "    \"nslice\": sliceselect\n",
    "        }\n",
    "\n",
    "response_multi_pred=predictor.predict(payload_multi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  Find the prediction results in S3 from response and download them locally \n",
    "    \n",
    "## to do challenge 1: find the results in S3(both through console and SageMaker SDK)\n",
    "## to do challenge 2: download the results\n",
    "## to do challenge 3: visualize the results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Perform inference on all slices\n",
    "\n",
    "For inference across all slices, the endpoint writes the output file to S3 and returns its S3 URI.\n",
    "\n",
    "In the payload, set \"nslice\" to \"all\" to run inference on all slices of a given image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "payload_all={\n",
    "    \"bucket\": bucket,\n",
    "    \"key\": prefix_key,\n",
    "    \"file\": file,\n",
    "    \"nslice\": \"all\"\n",
    "        }\n",
    "\n",
    "response_all_pred=predictor.predict(payload_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the results "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up the resources\n",
    "\n",
    "+ Delete the current endpoint, or all endpoints, to avoid ongoing costs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predictor.delete_predictor(delete_endpoint_config=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# client = boto3.client(\"sagemaker\")\n",
    "# endpoints=client.list_endpoints()[\"Endpoints\"]\n",
    "# endpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for endpoint in endpoints:\n",
    "#     response = client.delete_endpoint(\n",
    "#         EndpointName=endpoint[\"EndpointName\"]\n",
    "#     )"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "conda_pytorch_latest_p36",
   "language": "python",
   "name": "conda_pytorch_latest_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}