{ "cells": [ { "cell_type": "markdown", "metadata": { "jupyter": { "outputs_hidden": true } }, "source": [ "# Spleen 3D segmentation with MONAI\n", "\n", "This tutorial shows how to run SageMaker managed training using MONAI for 3D Segmentation.\n", "\n", "This notebook and train.py script in source folder were derived from [this notebook](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb)\n", "\n", "Key features demonstrated here:\n", "1. SageMaker managed training with EFS integration\n", "2. SageMaker Hyperparameter tuning \n", "\n", "The Spleen dataset can be downloaded from https://registry.opendata.aws/msd/.\n", "\n", "![spleen](http://medicaldecathlon.com/img/spleen0.png)\n", "\n", "Target: Spleen \n", "Modality: CT \n", "Size: 61 3D volumes (41 Training + 20 Testing) \n", "Source: Memorial Sloan Kettering Cancer Center \n", "Challenge: Large ranging foreground size\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python -c \"import monai\" || pip install -q \"monai-weekly[gdown, nibabel, tqdm, ignite]\"\n", "!python -c \"import matplotlib\" || pip install -q matplotlib\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from monai.utils import first, set_determinism\n", "from monai.transforms import (\n", " AsDiscrete,\n", " AsDiscreted,\n", " EnsureChannelFirstd,\n", " Compose,\n", " CropForegroundd,\n", " LoadImaged,\n", " Orientationd,\n", " RandCropByPosNegLabeld,\n", " ScaleIntensityRanged,\n", " Spacingd,\n", " EnsureTyped,\n", " EnsureType,\n", " Invertd,\n", ")\n", "from monai.handlers.utils import from_engine\n", "from monai.networks.nets import UNet\n", "from monai.networks.layers import Norm\n", "from monai.metrics import DiceMetric\n", "from monai.losses import DiceLoss\n", "from monai.inferers import sliding_window_inference\n", "from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch\n", "from monai.config import print_config\n", "from monai.apps import download_and_extract\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import tempfile\n", "import shutil\n", "import os\n", "import glob" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Download dataset if it is not available\n", "resource = \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar\"\n", "md5 = \"410d4a301da4e5b2f6f86ec3ddba524e\"\n", "compressed_file = \"./Task09_Spleen.tar\"\n", "\n", "MONAILabelServerIP = \"10.192.21.35\" ## IP address of the MONAI Label Server if deployed\n", "data_dir = MONAILabelServerIP\n", "\n", "if not os.path.exists(data_dir):\n", " download_and_extract(resource, compressed_file, data_dir + \"/datasets\", md5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "val_transforms = Compose(\n", " [\n", " LoadImaged(keys=[\"image\", \"label\"]),\n", " EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", " Spacingd(\n", " keys=[\"image\", \"label\"],\n", " pixdim=(1.5, 1.5, 2.0),\n", " mode=(\"bilinear\", \"nearest\"),\n", " ),\n", " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", " ScaleIntensityRanged(\n", " keys=[\"image\"],\n", " a_min=-57,\n", " a_max=164,\n", " b_min=0.0,\n", " b_max=1.0,\n", " clip=True,\n", " ),\n", " CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n", " EnsureTyped(keys=[\"image\", \"label\"]),\n", " ]\n", ")" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_images = sorted(\n", " glob.glob(os.path.join(data_dir, \"datasets/Task09_Spleen/imagesTr\", \"*.nii.gz\"))\n", ")\n", "train_labels = sorted(\n", " glob.glob(os.path.join(data_dir, \"datasets/Task09_Spleen/labelsTr\", \"*.nii.gz\"))\n", ")\n", "data_dicts = [\n", " {\"image\": image_name, \"label\": label_name}\n", " for image_name, label_name in zip(train_images, train_labels)\n", "]\n", "train_files, val_files = data_dicts[:-9], data_dicts[-9:]\n", "\n", "check_ds = Dataset(data=val_files, transform=val_transforms)\n", "check_loader = DataLoader(check_ds, batch_size=1)\n", "check_data = first(check_loader)\n", "image, label = (check_data[\"image\"][0][0], check_data[\"label\"][0][0])\n", "print(f\"image shape: {image.shape}, label shape: {label.shape}\")\n", "# plot the slice [:, :, 80]\n", "plt.figure(\"check\", (12, 6))\n", "plt.subplot(1, 2, 1)\n", "plt.title(\"image\")\n", "plt.imshow(image[:, :, 80], cmap=\"gray\")\n", "plt.subplot(1, 2, 2)\n", "plt.title(\"label\")\n", "plt.imshow(label[:, :, 80])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## To collect information like subnets and security group to submit the training job with EFS data sources;\n", "import boto3\n", "\n", "sm = boto3.client(\"sagemaker\")\n", "efs = boto3.client(\"efs\")\n", "ec2 = boto3.client(\"ec2\")\n", "sm_domains = sm.list_domains()\n", "sm_domain = sm.describe_domain(DomainId=sm_domains[\"Domains\"][0][\"DomainId\"])\n", "\n", "UserProfileName = \"sagemaker-userprofile-for-demo\" ## this is hard code UserProfile name in CFN template, please replace it if needed\n", "sm_user = sm.describe_user_profile(\n", " DomainId=sm_domains[\"Domains\"][0][\"DomainId\"], UserProfileName=UserProfileName\n", ")\n", "\n", "## if the SageMaker studio domain and userprofile was created by CloudFormation deployment the UserSettings has security group associated, if not we will need to find the security group that can access Home EFS and grant egress\n", "if \"UserSettings\" in sm_user and \"SecurityGroups\" in sm_user[\"UserSettings\"]:\n", " training_securitygroup = sm_user[\"UserSettings\"][\"SecurityGroups\"]\n", "## the SageMaker studio execution role should have permissioon to describe mount target and authorize egress to security group\n", "else:\n", " mounttargets = efs.describe_mount_targets(\n", " FileSystemId=sm_domain[\"HomeEfsFileSystemId\"]\n", " )\n", " securitygroup = ec2.describe_security_groups(\n", " Filters=[\n", " {\n", " \"Name\": \"group-id\",\n", " \"Values\": efs.describe_mount_target_security_groups(\n", " MountTargetId=mounttargets[\"MountTargets\"][0][\"MountTargetId\"]\n", " )[\"SecurityGroups\"],\n", " }\n", " ]\n", " )[\"SecurityGroups\"][0]\n", "\n", " ec2r = boto3.resource(\"ec2\")\n", " securitygroup = ec2r.SecurityGroup(\n", " securitygroup[\"IpPermissions\"][0][\"UserIdGroupPairs\"][0][\"GroupId\"]\n", " )\n", " securitygroup.authorize_egress(\n", " IpPermissions=[\n", " {\n", " \"IpProtocol\": \"-1\",\n", " \"IpRanges\": [{\"CidrIp\": \"0.0.0.0/0\"}],\n", " \"Ipv6Ranges\": [],\n", " \"PrefixListIds\": [],\n", " \"UserIdGroupPairs\": [],\n", " }\n", " ]\n", " )\n", " training_securitygroup = [securitygroup.id]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker.inputs import FileSystemInput\n", "from sagemaker.pytorch import PyTorch\n", 
"\n", "sagemaker_session = sagemaker.Session()\n", "role = sagemaker.get_execution_role()\n", "\n", "metrics = [\n", " {\"Name\": \"train:average epoch loss\", \"Regex\": \"average loss: ([0-9\\\\.]*)\"},\n", " {\"Name\": \"train:current mean dice\", \"Regex\": \"current mean dice: ([0-9\\\\.]*)\"},\n", " {\"Name\": \"train:best mean dice\", \"Regex\": \"best mean dice: ([0-9\\\\.]*)\"},\n", "]\n", "\n", "estimator = PyTorch(\n", " source_dir=\"source\",\n", " entry_point=\"train.py\",\n", " role=role,\n", " framework_version=\"1.6.0\",\n", " py_version=\"py3\",\n", " instance_count=1,\n", " instance_type=\"ml.p2.xlarge\",\n", " subnets=sm_domain[\"SubnetIds\"],\n", " security_group_ids=training_securitygroup,\n", " hyperparameters={\"seed\": 2, \"lr\": 0.001, \"epochs\": 10},\n", " metric_definitions=metrics,\n", " # ### spot instance training ###\n", " # use_spot_instances=True,\n", " # max_run=2400,\n", " # max_wait=2400\n", ")\n", "\n", "NotebookHostPath = MONAILabelServerIP\n", "file_system_input = FileSystemInput(\n", " file_system_id=sm_domain[\"HomeEfsFileSystemId\"],\n", " file_system_type=\"EFS\",\n", " directory_path=\"/{0}/{1}\".format(sm_user[\"HomeEfsFileSystemUid\"], NotebookHostPath),\n", " file_system_access_mode=\"rw\",\n", ")\n", "\n", "# Start an Amazon SageMaker training job with EFS using the FileSystemInput class\n", "estimator.fit(file_system_input)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## hyperparameter tuning (optional to run)\n", "\n", "objective_metric_name = \"train:current mean dice\"\n", "\n", "hyperparameter_ranges = {\n", " \"lr\": sagemaker.tuner.ContinuousParameter(0.001, 0.1),\n", " \"epochs\": sagemaker.tuner.CategoricalParameter([1, 5, 10]),\n", "}\n", "\n", "tuner = sagemaker.tuner.HyperparameterTuner(\n", " estimator,\n", " objective_metric_name,\n", " hyperparameter_ranges,\n", " metrics,\n", " max_jobs=1,\n", " max_parallel_jobs=1,\n", " objective_type=\"Maximize\",\n", ")\n", "\n", "tuner.fit(file_system_input)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.p2.xlarge\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (PyTorch 1.8 Python 3.6 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/1.8.1-cpu-py36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }