# from monai.utils import first, set_determinism
import numpy
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os, sys, glob, argparse, json
import logging

from inference import model_fn, input_fn, predict_fn, output_fn

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def train(args):
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug("Number of gpus available - {}".format(args.num_gpus))
    kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")
    batch_size = args.batch_size

    # set the seed for generating random numbers
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Load data: pair each training image with its label file
    train_images = sorted(
        glob.glob(os.path.join(args.data_dir, "imagesTr", "*.nii.gz")))
    train_labels = sorted(
        glob.glob(os.path.join(args.data_dir, "labelsTr", "*.nii.gz")))
    data_dicts = [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(train_images, train_labels)
    ]
    # hold out the last 9 cases for validation
    train_files, val_files = data_dicts[:-9], data_dicts[-9:]

    # transforms for the training dataset
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(
                1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            # user can also add other random transforms
            # RandAffined(
            #     keys=['image', 'label'],
            #     mode=('bilinear', 'nearest'),
            #     prob=1.0, spatial_size=(96, 96, 96),
            #     rotate_range=(0, 0, np.pi/15),
            #     scale_range=(0.1, 0.1, 0.1)),
            EnsureTyped(keys=["image", "label"]),
        ]
    )

    # transforms for the validation dataset (deterministic only, no random crops)
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(
                1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
        ]
    )

    train_ds = CacheDataset(
        data=train_files, transform=train_transforms, cache_rate=1.0)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld
    # to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)

    val_ds = CacheDataset(
        data=val_files, transform=val_transforms,
        cache_rate=1.0)
    val_loader = DataLoader(val_ds, batch_size=1)

    # create model (device was already selected above based on GPU availability)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    dice_metric = DiceMetric(include_background=False, reduction="mean")

    # train model
    max_epochs = args.epochs
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    # torch.save(model.state_dict(), os.path.join(args.data_dir, "best_metric_model.pth"))
                    print("saved new best metric model")
                    save_model(model, args.model_dir)
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )

    # package the inference code with the model artifact
    model_code_dir = os.path.join(args.model_dir, 'code')
    os.makedirs(model_code_dir, exist_ok=True)
    shutil.copy('/opt/ml/code/inference.py', model_code_dir)  # copy the inference file
    shutil.copy('/opt/ml/code/requirements.txt', model_code_dir)  # copy requirements.txt


def save_model(model, model_dir):
    logger.info("Saving the model.")
    path = os.path.join(model_dir, 'model.pth')
    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
    # torch.save(model.state_dict(), path)
    torch.save(model, path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Data and model checkpoints directories
    parser.add_argument('--batch-size', type=int, default=2, metavar='N',
                        help='input batch size for training (default: 2)')
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs', type=int, default=600, metavar='N',
                        help='number of epochs to train (default: 600)')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--backend', type=str, default=None,
                        help='backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)')

    # Container environment
    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--data-dir', type=str,
                        default=os.environ['SM_CHANNEL_TRAINING']
                        if 'SM_CHANNEL_TRAINING' in os.environ else '/opt/ml/input/')
    parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])

    train(parser.parse_args())
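
# Usage note (a minimal sketch, not part of the script's execution path): this
# entry point reads SageMaker environment variables (SM_HOSTS, SM_MODEL_DIR,
# SM_CHANNEL_TRAINING, ...) and copies /opt/ml/code/inference.py into the model
# artifact, so it is normally launched through the SageMaker Python SDK rather
# than run directly. The snippet below is one way that launch might look; the
# entry-point filename "train.py", the "code" source directory, the S3 prefix,
# the framework/py versions, and the instance type are all assumptions, not
# values taken from this repository.
#
#     from sagemaker.pytorch import PyTorch
#
#     estimator = PyTorch(
#         entry_point="train.py",          # this script (hypothetical filename)
#         source_dir="code",               # should also contain inference.py and requirements.txt
#         role=role,                       # an IAM role with SageMaker permissions
#         framework_version="1.12",        # assumed PyTorch container version
#         py_version="py38",
#         instance_count=1,
#         instance_type="ml.g4dn.xlarge",  # any single-GPU instance works for this sketch
#         hyperparameters={"epochs": 600, "batch-size": 2, "lr": 1e-4},
#     )
#     # the "training" channel maps to SM_CHANNEL_TRAINING, i.e. args.data_dir;
#     # the S3 prefix is expected to hold imagesTr/ and labelsTr/ with *.nii.gz files
#     estimator.fit({"training": "s3://<bucket>/<prefix>"})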