{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Spleen 3D segmentation with MONAI\n", "\n", "This tutorial shows how to integrate MONAI into an existing PyTorch medical DL program.\n", "\n", "And easily use below features:\n", "1. Transforms for dictionary format data.\n", "1. Load Nifti image with metadata.\n", "1. Add channel dim to the data if no channel dimension.\n", "1. Scale medical image intensity with expected range.\n", "1. Crop out a batch of balanced images based on positive / negative label ratio.\n", "1. Cache IO and transforms to accelerate training and validation.\n", "1. 3D UNet model, Dice loss function, Mean Dice metric for 3D segmentation task.\n", "1. Sliding window inference method.\n", "1. Deterministic training for reproducibility.\n", "\n", "The Spleen dataset can be downloaded from http://medicaldecathlon.com/.\n", "\n", "![spleen](http://medicaldecathlon.com/img/spleen0.png)\n", "\n", "Target: Spleen \n", "Modality: CT \n", "Size: 61 3D volumes (41 Training + 20 Testing) \n", "Source: Memorial Sloan Kettering Cancer Center \n", "Challenge: Large ranging foreground size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "!pip uninstall -y monai monai-weekly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q \"monai[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide]==0.7.0\"\n", "#!pip list | grep monai" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#!python -c \"import monai\" || pip install -q \"monai-weekly[gdown, nibabel, tqdm]\"\n", "!python -c \"import matplotlib\" || pip install -q matplotlib\n", "# !pip install -q pytorch-lightning==1.4.0\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup imports" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "from monai.utils import first, set_determinism\n", "from monai.transforms import (\n", " AsDiscrete,\n", " AsDiscreted,\n", " EnsureChannelFirstd,\n", " Compose,\n", " CropForegroundd,\n", " LoadImaged,\n", " Orientationd,\n", " RandCropByPosNegLabeld,\n", " ScaleIntensityRanged,\n", " Spacingd,\n", " EnsureTyped,\n", " EnsureType,\n", " Invertd,\n", ")\n", "from monai.handlers.utils import from_engine\n", "from monai.networks.nets import UNet\n", "from monai.networks.layers import Norm\n", "from monai.metrics import DiceMetric\n", "from monai.losses import DiceLoss\n", "from monai.inferers import sliding_window_inference\n", "from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch\n", "from monai.config import print_config\n", "from monai.apps import download_and_extract\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import tempfile\n", "import shutil\n", "import os\n", "import glob" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup data directory\n", "\n", "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. \n", "This allows you to save results and reuse downloads. \n", "If not specified a temporary directory will be used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pathlib, os, glob\n", "\n", "model_dir = \"10.192.21.7/models\" ## a folder to keep all the models\n", "pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)\n", "\n", "# make sure that the folder is empty\n", "files = glob.glob(model_dir + \"/*\")\n", "for f in files:\n", " os.remove(f)\n", "\n", "directory = \"10.192.21.7/datasets\" ## input data folder\n", "root_dir = tempfile.mkdtemp() if directory is None else directory\n", "data_dir = os.path.join(root_dir, \"Task09_Spleen\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download dataset\n", "\n", "Downloads and extracts the dataset. \n", "The dataset comes from http://medicaldecathlon.com/." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "resource = \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar\"\n", "md5 = \"410d4a301da4e5b2f6f86ec3ddba524e\"\n", "\n", "compressed_file = os.path.join(root_dir, \"Task09_Spleen.tar\")\n", "\n", "if not os.path.exists(data_dir):\n", " download_and_extract(resource, compressed_file, root_dir, md5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set MSD Spleen dataset path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_images = sorted(glob.glob(os.path.join(data_dir, \"imagesTr\", \"*.nii.gz\")))\n", "train_labels = sorted(glob.glob(os.path.join(data_dir, \"labelsTr\", \"*.nii.gz\")))\n", "data_dicts = [\n", " {\"image\": image_name, \"label\": label_name}\n", " for image_name, label_name in zip(train_images, train_labels)\n", "]\n", "train_files, val_files = (\n", " data_dicts[:-9],\n", " data_dicts[-9:],\n", ") ## keep the last 9 files as validation files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set deterministic training for reproducibility" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "set_determinism(seed=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup transforms for training and validation\n", "\n", "Here we use several transforms to augment the dataset:\n", "1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.\n", "1. `AddChanneld` as the original data doesn't have channel dim, add 1 dim to construct \"channel first\" shape.\n", "1. `Spacingd` adjusts the spacing by `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.\n", "1. `Orientationd` unifies the data orientation based on the affine matrix.\n", "1. `ScaleIntensityRanged` extracts intensity range [-57, 164] and scales to [0, 1].\n", "1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.\n", "1. `RandCropByPosNegLabeld` randomly crop patch samples from big image based on pos / neg ratio. \n", "The image centers of negative samples must be in valid body area.\n", "1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform.\n", "1. `EnsureTyped` converts the numpy array to PyTorch Tensor for further steps." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_transforms = Compose(\n", " [\n", " LoadImaged(keys=[\"image\", \"label\"]),\n", " EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", " Spacingd(\n", " keys=[\"image\", \"label\"],\n", " pixdim=(1.5, 1.5, 2.0),\n", " mode=(\"bilinear\", \"nearest\"),\n", " ),\n", " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", " ScaleIntensityRanged(\n", " keys=[\"image\"],\n", " a_min=-57,\n", " a_max=164,\n", " b_min=0.0,\n", " b_max=1.0,\n", " clip=True,\n", " ),\n", " CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n", " RandCropByPosNegLabeld(\n", " keys=[\"image\", \"label\"],\n", " label_key=\"label\",\n", " spatial_size=(96, 96, 96),\n", " pos=1,\n", " neg=1,\n", " num_samples=4,\n", " image_key=\"image\",\n", " image_threshold=0,\n", " ),\n", " # user can also add other random transforms\n", " # RandAffined(\n", " # keys=['image', 'label'],\n", " # mode=('bilinear', 'nearest'),\n", " # prob=1.0, spatial_size=(96, 96, 96),\n", " # rotate_range=(0, 0, np.pi/15),\n", " # scale_range=(0.1, 0.1, 0.1)),\n", " EnsureTyped(keys=[\"image\", \"label\"]),\n", " ]\n", ")\n", "val_transforms = Compose(\n", " [\n", " LoadImaged(keys=[\"image\", \"label\"]),\n", " EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", " Spacingd(\n", " keys=[\"image\", \"label\"],\n", " pixdim=(1.5, 1.5, 2.0),\n", " mode=(\"bilinear\", \"nearest\"),\n", " ),\n", " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", " ScaleIntensityRanged(\n", " keys=[\"image\"],\n", " a_min=-57,\n", " a_max=164,\n", " b_min=0.0,\n", " b_max=1.0,\n", " clip=True,\n", " ),\n", " CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n", " EnsureTyped(keys=[\"image\", \"label\"]),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Check transforms in DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { 
"tags": [] }, "outputs": [], "source": [ "check_ds = Dataset(data=val_files, transform=val_transforms)\n", "check_loader = DataLoader(check_ds, batch_size=1)\n", "check_data = first(check_loader)\n", "image, label = (check_data[\"image\"][0][0], check_data[\"label\"][0][0])\n", "print(f\"image shape: {image.shape}, label shape: {label.shape}\")\n", "# plot the slice [:, :, 80]\n", "plt.figure(\"check\", (12, 6))\n", "plt.subplot(1, 2, 1)\n", "plt.title(\"image\")\n", "plt.imshow(image[:, :, 80], cmap=\"gray\")\n", "plt.subplot(1, 2, 2)\n", "plt.title(\"label\")\n", "plt.imshow(label[:, :, 80])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define CacheDataset and DataLoader for training and validation\n", "\n", "Here we use CacheDataset to accelerate training and validation process, it's 10x faster than the regular Dataset. \n", "To achieve best performance, set `cache_rate=1.0` to cache all the data, if memory is not enough, set lower value. \n", "Users can also set `cache_num` instead of `cache_rate`, will use the minimum value of the 2 settings. \n", "And set `num_workers` to enable multi-threads during caching. \n", "If want to to try the regular Dataset, just change to use the commented code below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "train_ds = CacheDataset(\n", " data=train_files, transform=train_transforms, cache_rate=1, num_workers=0\n", ")\n", "# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)\n", "\n", "# use batch_size=2 to load images and use RandCropByPosNegLabeld\n", "# to generate 2 x 4 images for network training\n", "train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)\n", "\n", "val_ds = CacheDataset(\n", " data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=0\n", ")\n", "# val_ds = Dataset(data=val_files, transform=val_transforms)\n", "val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)" ] }, { "cell_type": "markdown", "metadata": { "jupyter": { "outputs_hidden": true } }, "source": [ "!top" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Model, Loss, Optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer\n", "device = torch.device(\"cuda:0\")\n", "model = UNet(\n", " spatial_dims=3,\n", " in_channels=1,\n", " out_channels=2,\n", " channels=(16, 32, 64, 128, 256),\n", " strides=(2, 2, 2, 2),\n", " num_res_units=2,\n", " norm=Norm.BATCH,\n", ").to(device)\n", "loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n", "optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n", "dice_metric = DiceMetric(include_background=False, reduction=\"mean\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Execute a typical PyTorch training process" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "%%time\n", "max_epochs = 20 # 600\n", "val_interval = 2\n", "best_metric = -1\n", "best_metric_epoch = -1\n", "epoch_loss_values = []\n", "metric_values = []\n", 
"post_pred = Compose(\n", " [EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)]\n", ")\n", "post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])\n", "\n", "for epoch in range(max_epochs):\n", " print(\"-\" * 10)\n", " print(f\"epoch {epoch + 1}/{max_epochs}\")\n", " model.train()\n", " epoch_loss = 0\n", " step = 0\n", " for batch_data in train_loader:\n", " step += 1\n", " inputs, labels = (\n", " batch_data[\"image\"].to(device),\n", " batch_data[\"label\"].to(device),\n", " )\n", " optimizer.zero_grad()\n", " outputs = model(inputs)\n", " loss = loss_function(outputs, labels)\n", " loss.backward()\n", " optimizer.step()\n", " epoch_loss += loss.item()\n", " print(\n", " f\"{step}/{len(train_ds) // train_loader.batch_size}, \"\n", " f\"train_loss: {loss.item():.4f}\"\n", " )\n", " epoch_loss /= step\n", " epoch_loss_values.append(epoch_loss)\n", " print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n", "\n", " if (epoch + 1) % val_interval == 0:\n", " model.eval()\n", " with torch.no_grad():\n", " for val_data in val_loader:\n", " val_inputs, val_labels = (\n", " val_data[\"image\"].to(device),\n", " val_data[\"label\"].to(device),\n", " )\n", " roi_size = (160, 160, 160)\n", " sw_batch_size = 4\n", " val_outputs = sliding_window_inference(\n", " val_inputs, roi_size, sw_batch_size, model\n", " )\n", " val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n", " val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n", " # compute metric for current iteration\n", " dice_metric(y_pred=val_outputs, y=val_labels)\n", "\n", " # aggregate the final mean dice result\n", " metric = dice_metric.aggregate().item()\n", " # reset the status for next validation round\n", " dice_metric.reset()\n", "\n", " metric_values.append(metric)\n", " if metric > best_metric:\n", " best_metric = metric\n", " best_metric_epoch = epoch + 1\n", " torch.save(model.state_dict(), model_dir + 
\"/best_metric_model.pth\")\n", " print(\"saved new best metric model\")\n", " print(\n", " f\"current epoch: {epoch + 1} current mean dice: {metric:.4f}\"\n", " f\"\\nbest mean dice: {best_metric:.4f} \"\n", " f\"at epoch: {best_metric_epoch}\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "print(\n", " f\"train completed, best_metric: {best_metric:.4f} \" f\"at epoch: {best_metric_epoch}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the loss and metric" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(\"train\", (12, 6))\n", "plt.subplot(1, 2, 1)\n", "plt.title(\"Epoch Average Loss\")\n", "x = [i + 1 for i in range(len(epoch_loss_values))]\n", "y = epoch_loss_values\n", "plt.xlabel(\"epoch\")\n", "plt.plot(x, y)\n", "plt.subplot(1, 2, 2)\n", "plt.title(\"Val Mean Dice\")\n", "x = [val_interval * (i + 1) for i in range(len(metric_values))]\n", "y = metric_values\n", "plt.xlabel(\"epoch\")\n", "plt.plot(x, y)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Check best model output with the input image and label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.load_state_dict(torch.load(model_dir + \"/best_metric_model.pth\"))\n", "model.eval()\n", "with torch.no_grad():\n", " for i, val_data in enumerate(val_loader):\n", " roi_size = (160, 160, 160)\n", " sw_batch_size = 4\n", " val_outputs = sliding_window_inference(\n", " val_data[\"image\"].to(device), roi_size, sw_batch_size, model\n", " )\n", " # plot the slice [:, :, 80]\n", " plt.figure(\"check\", (18, 6))\n", " plt.subplot(1, 3, 1)\n", " plt.title(f\"image {i}\")\n", " plt.imshow(val_data[\"image\"][0, 0, :, :, 80], cmap=\"gray\")\n", " plt.subplot(1, 3, 2)\n", " plt.title(f\"label {i}\")\n", " plt.imshow(val_data[\"label\"][0, 0, :, :, 80])\n", " plt.subplot(1, 
3, 3)\n", " plt.title(f\"output {i}\")\n", " plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 80])\n", " plt.show()\n", " if i == 2:\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on original image spacings" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "val_org_transforms = Compose(\n", " [\n", " LoadImaged(keys=[\"image\", \"label\"]),\n", " EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", " Spacingd(keys=[\"image\"], pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n", " Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n", " ScaleIntensityRanged(\n", " keys=[\"image\"],\n", " a_min=-57,\n", " a_max=164,\n", " b_min=0.0,\n", " b_max=1.0,\n", " clip=True,\n", " ),\n", " CropForegroundd(keys=[\"image\"], source_key=\"image\"),\n", " EnsureTyped(keys=[\"image\", \"label\"]),\n", " ]\n", ")\n", "\n", "val_org_ds = Dataset(data=val_files, transform=val_org_transforms)\n", "val_org_loader = DataLoader(val_org_ds, batch_size=1, num_workers=4)\n", "\n", "post_transforms = Compose(\n", " [\n", " EnsureTyped(keys=\"pred\"),\n", " Invertd(\n", " keys=\"pred\",\n", " transform=val_org_transforms,\n", " orig_keys=\"image\",\n", " meta_keys=\"pred_meta_dict\",\n", " orig_meta_keys=\"image_meta_dict\",\n", " meta_key_postfix=\"meta_dict\",\n", " nearest_interp=False,\n", " to_tensor=True,\n", " ),\n", " AsDiscreted(keys=\"pred\", argmax=True, to_onehot=True, num_classes=2),\n", " AsDiscreted(keys=\"label\", to_onehot=True, num_classes=2),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cleanup data directory\n", "\n", "Remove the data directory if a temporary one was used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if directory is None:\n", " shutil.rmtree(root_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.g4dn.2xlarge", "kernelspec": { "display_name": "Python 3 (PyTorch 1.6 Python 3.6 GPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.6-gpu-py36-cu110-ubuntu18.04-v3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }