In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
import boto3

In [None]:
# AWS handles and S3 locations shared by the rest of the notebook.
role = sagemaker.get_execution_role()  # execution role later passed to the estimator
s3_client = boto3.client("s3")
sess = sagemaker.session.Session()
bucket = sess.default_bucket()  # the session's default bucket; used for the dataset upload
key_prefix = "pt_lightning_ddp_tune"  # S3 key prefix under which this example stores its data

In [None]:
# Download the CIFAR-10 python archive, extract it into the working directory,
# then delete the archive. The steps are chained with `&&`, so each one runs
# only if the previous one succeeded.
!wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz && tar xzf cifar-10-python.tar.gz && rm cifar-10-python.tar.gz

In [None]:
# Upload the local "cifar-10-batches-py" directory to the bucket under
# <key_prefix>/input_data/cifar-10-batches-py and keep the value returned by
# upload_data for the next cell.
cifar_data_path = sess.upload_data(
    path="cifar-10-batches-py",
    bucket=bucket,
    key_prefix=f"{key_prefix}/input_data/cifar-10-batches-py",
)

In [None]:
# Strip the final "/"-separated component, so the path used for training is the
# parent of the uploaded location.
cifar_path = cifar_data_path.rpartition("/")[0]

In [None]:
# Optional networking settings forwarded to the estimator below.
# Both default to None; replace with your own subnets / security group IDs if needed.
subnets = security_group_ids = None

In [None]:
# Hyperparameters handed to the training entry point.
tune_cifar_hyperparameters = {
    "use-gpu": True,   # use GPU for training
    "num-samples": 4,  # number of trials to run for HPO
    "num-workers": 2,  # number of GPUs per training run (Data Parallel distributed training)
    "num-epochs": 5,   # number of epochs to train each model for
}

# PyTorch estimator that runs src/tune_cifar.py on two ml.g4dn.xlarge instances.
estimator_gpu_tune_cifar = PyTorch(
    # What to run
    entry_point="tune_cifar.py",
    source_dir="src",
    hyperparameters=tune_cifar_hyperparameters,
    # Runtime
    framework_version="1.10",
    py_version="py38",
    # Where to run it
    role=role,
    instance_type="ml.g4dn.xlarge",  # single-GPU instance; use g4dn.12xlarge or g5.12xlarge for multi-GPU instances
    instance_count=2,
    subnets=subnets,
    security_group_ids=security_group_ids,
)

In [None]:
# Start the training job, supplying the uploaded CIFAR-10 location as the "train" channel.
estimator_gpu_tune_cifar.fit(inputs={"train": cifar_path})