# Spleen 3D segmentation with MONAI

This tutorial shows how to run SageMaker managed training with MONAI for 3D segmentation.

This notebook and the `train.py` script in the `source` folder were derived from [this notebook](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).

Key features demonstrated here:
1. SageMaker managed training with EFS integration
2. SageMaker hyperparameter tuning

The Spleen dataset can be downloaded from https://registry.opendata.aws/msd/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

- **Target:** Spleen
- **Modality:** CT
- **Size:** 61 3D volumes (41 training + 20 testing)
- **Source:** Memorial Sloan Kettering Cancer Center
- **Challenge:** Large-ranging foreground size
 

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
 AsDiscrete,
 AsDiscreted,
 EnsureChannelFirstd,
 Compose,
 CropForegroundd,
 LoadImaged,
 Orientationd,
 RandCropByPosNegLabeld,
 ScaleIntensityRanged,
 Spacingd,
 EnsureTyped,
 EnsureType,
 Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

In [None]:
## Download the dataset if it is not already available.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = "./Task09_Spleen.tar"

MONAILabelServerIP = "10.192.21.35"  ## IP address of the MONAI Label Server if deployed
# The local data directory is named after the MONAI Label server IP. The training cell later
# mounts the EFS directory "/<uid>/<MONAILabelServerIP>", so the two must stay in sync.
data_dir = MONAILabelServerIP

# Test for the extracted task folder itself, not just `data_dir`. The directory may already
# exist without the dataset having been downloaded, which would otherwise leave the
# imagesTr/labelsTr globs used later empty.
task_dir = os.path.join(data_dir, "datasets", "Task09_Spleen")
if not os.path.exists(task_dir):
    download_and_extract(
        resource, compressed_file, os.path.join(data_dir, "datasets"), md5
    )

In [None]:
# Deterministic preprocessing pipeline for validation volumes
# (also used by the sanity-check loader that follows).
val_transforms = Compose(
    [
        # Read the image/label pair, then put the channel axis first.
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to a common voxel spacing: bilinear for the image,
        # nearest-neighbour for the label so class values stay discrete.
        Spacingd(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            pixdim=(1.5, 1.5, 2.0),
        ),
        Orientationd(axcodes="RAS", keys=["image", "label"]),
        # Map the intensity window [-57, 164] linearly onto [0, 1], clipping outliers.
        ScaleIntensityRanged(
            keys=["image"],
            clip=True,
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
        ),
        # Crop image and label to the foreground bounding box of the image.
        CropForegroundd(source_key="image", keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
    ]
)

In [None]:
train_images = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/imagesTr", "*.nii.gz"))
)
train_labels = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/labelsTr", "*.nii.gz"))
)

# Fail loudly instead of letting `zip` silently drop unmatched files, or letting the
# sanity check below crash on an empty loader with an opaque TypeError.
if not train_images:
    raise FileNotFoundError(
        f"No training images found under {data_dir}/datasets/Task09_Spleen/imagesTr"
    )
# Images and labels are paired by sorted position, so both folders must hold the same file names.
if [os.path.basename(p) for p in train_images] != [
    os.path.basename(p) for p in train_labels
]:
    raise ValueError(
        "imagesTr and labelsTr do not contain the same file names; "
        "image/label pairs would be misaligned"
    )

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# Hold out the last 9 cases for validation.
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

# Sanity check: run one validation case through the transforms and display it.
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
# Index [0][0] drops the batch and channel dimensions.
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# Plot slice 80 along the last axis, clamped in case the cropped volume is shorter.
slice_idx = min(80, image.shape[-1] - 1)
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

In [None]:
## Collect the subnets and security group needed to submit the training job with an EFS data source.
import boto3
from botocore.exceptions import ClientError

sm = boto3.client("sagemaker")
efs = boto3.client("efs")
ec2 = boto3.client("ec2")

sm_domains = sm.list_domains()
if not sm_domains["Domains"]:
    raise RuntimeError("No SageMaker Studio domain was found in this account/region.")
# Assumes the first domain listed is the one hosting this notebook -- TODO confirm for multi-domain accounts.
domain_id = sm_domains["Domains"][0]["DomainId"]
sm_domain = sm.describe_domain(DomainId=domain_id)

UserProfileName = "sagemaker-userprofile-for-demo"  ## user profile name hard-coded in the CFN template; replace it if needed
sm_user = sm.describe_user_profile(DomainId=domain_id, UserProfileName=UserProfileName)

## If the Studio domain and user profile were created by the CloudFormation deployment, UserSettings
## already lists the security group(s) to use. Otherwise, find the security group that can reach the
## home EFS and grant it egress. The Studio execution role needs permission to describe EFS mount
## targets and to authorize security-group egress.
user_settings = sm_user.get("UserSettings", {})
if "SecurityGroups" in user_settings:
    training_securitygroup = user_settings["SecurityGroups"]
else:
    mounttargets = efs.describe_mount_targets(
        FileSystemId=sm_domain["HomeEfsFileSystemId"]
    )
    mounttarget_sg_ids = efs.describe_mount_target_security_groups(
        MountTargetId=mounttargets["MountTargets"][0]["MountTargetId"]
    )["SecurityGroups"]
    securitygroup = ec2.describe_security_groups(
        Filters=[{"Name": "group-id", "Values": mounttarget_sg_ids}]
    )["SecurityGroups"][0]

    # Assumes the first ingress rule of the mount target's security group references the group
    # that is allowed to reach EFS; training will run in that group. TODO confirm.
    ec2r = boto3.resource("ec2")
    securitygroup = ec2r.SecurityGroup(
        securitygroup["IpPermissions"][0]["UserIdGroupPairs"][0]["GroupId"]
    )
    # NOTE(review): this allows all outbound traffic (any protocol, 0.0.0.0/0) from the group.
    try:
        securitygroup.authorize_egress(
            IpPermissions=[
                {
                    "IpProtocol": "-1",
                    "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
                    "Ipv6Ranges": [],
                    "PrefixListIds": [],
                    "UserIdGroupPairs": [],
                }
            ]
        )
    except ClientError as err:
        # An identical rule already exists (e.g. the cell was re-run, or the group already allows
        # all outbound traffic). That is the desired end state, so only re-raise other errors.
        if err.response["Error"]["Code"] != "InvalidPermission.Duplicate":
            raise
    training_securitygroup = [securitygroup.id]

In [None]:
import sagemaker
from sagemaker.inputs import FileSystemInput
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

# CloudWatch metrics scraped from the training log: (metric name, capture regex).
metrics = [
    {"Name": metric_name, "Regex": pattern}
    for metric_name, pattern in (
        ("train:average epoch loss", "average loss: ([0-9\\.]*)"),
        ("train:current mean dice", "current mean dice: ([0-9\\.]*)"),
        ("train:best mean dice", "best mean dice: ([0-9\\.]*)"),
    )
]

estimator = PyTorch(
    # Training code and framework
    entry_point="train.py",
    source_dir="source",
    framework_version="1.6.0",
    py_version="py3",
    hyperparameters={"seed": 2, "lr": 0.001, "epochs": 10},
    metric_definitions=metrics,
    # Compute
    role=role,
    instance_type="ml.p2.xlarge",
    instance_count=1,
    # Networking: run inside the Studio VPC so the job can mount the home EFS
    subnets=sm_domain["SubnetIds"],
    security_group_ids=training_securitygroup,
    # Optional spot-instance training:
    # use_spot_instances=True,
    # max_run=2400,
    # max_wait=2400
)

# EFS directory that is mounted as the training channel: "/<home EFS uid>/<NotebookHostPath>".
# NOTE(review): this assumes the notebook's working directory is the root of the user's home EFS,
# so the dataset downloaded under `MONAILabelServerIP/datasets` is found here -- TODO confirm.
NotebookHostPath = MONAILabelServerIP
file_system_input = FileSystemInput(
    file_system_id=sm_domain["HomeEfsFileSystemId"],
    file_system_type="EFS",
    directory_path=f"/{sm_user['HomeEfsFileSystemUid']}/{NotebookHostPath}",
    file_system_access_mode="rw",
)

# Launch the managed training job reading its data from EFS.
estimator.fit(file_system_input)

In [None]:
## Hyperparameter tuning (optional to run).

# Must match one of the metric names declared in `metrics`.
objective_metric_name = "train:current mean dice"

hyperparameter_ranges = {
    "lr": sagemaker.tuner.ContinuousParameter(min_value=0.001, max_value=0.1),
    "epochs": sagemaker.tuner.CategoricalParameter([1, 5, 10]),
}

# A single job keeps this demo cheap; raise max_jobs for a real search.
tuner = sagemaker.tuner.HyperparameterTuner(
    estimator=estimator,
    objective_metric_name=objective_metric_name,
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metrics,
    objective_type="Maximize",
    max_jobs=1,
    max_parallel_jobs=1,
)

tuner.fit(file_system_input)

In [None]:
# Host the trained model on a real-time endpoint.
# NOTE(review): serving a PyTorch estimator needs model-loading handlers (e.g. `model_fn`) in the
# entry point; confirm `source/train.py` provides them. The endpoint is billed until it is deleted
# (e.g. `predictor.delete_endpoint()`).
predictor = estimator.deploy(
    instance_type="ml.p2.xlarge",
    initial_instance_count=1,
)