In [None]:
%load_ext nb_black

import sagemaker
import boto3

# 1. Prepare Training Data

In [None]:
%%bash
# Download the TSP dataset archive unless a previous run already fetched it.
# The archive lives at ../tsp-data.tar.gz, which is the path the extraction
# cell reads, so the existence check must look at that same location.
if [ -f ../tsp-data.tar.gz ]; then
  echo "File tsp-data.tar.gz exists."
else
  echo "File tsp-data.tar.gz does not exist."
  # Quote the URL so the shell never treats '?' as a glob character.
  gdown "https://drive.google.com/uc?id=152mpCze-v4d0m9kdsCeVkLdHFkjeDeF5"
  mv tsp-data.tar.gz ../
fi

In [None]:
%%bash
# Unpack the dataset archive into the parent directory, but only when the
# ../data folder is not already present.
if [ ! -d ../data ]; then
  echo "Folder data does not exist."
  tar -xvzf ../tsp-data.tar.gz -C ../
else
  echo "Folder data exists."
fi

In [None]:
# Open a SageMaker session; it is used here only to look up the default bucket.
session = sagemaker.Session()
# Name of the session's default S3 bucket. The cells below upload the TSP
# data to this bucket and point the training channels at it.
BUCKET = session.default_bucket()

In [None]:
# Upload the TSP dataset files to s3://{BUCKET}/data/tsp/, the prefix that both
# the "train" and "val" channels of the training job point to.
s3 = boto3.resource("s3")
for filename in [
    # Training set named by the estimator's "train_dataset" hyperparameter;
    # it must be in S3 as well, otherwise the train channel has no file
    # with that name (assumes run.py resolves it inside the channel -- verify).
    "tsp20-50_train_concorde.txt",
    # Validation sets named by the "val_datasets" hyperparameter.
    "tsp20_test_concorde.txt",
    "tsp50_test_concorde.txt",
    "tsp100_test_concorde.txt",
]:
    s3.meta.client.upload_file(
        f"../data/tsp/{filename}", BUCKET, f"data/tsp/{filename}"
    )

# 2. Distributed Training

In [None]:
# Session and IAM execution role used by the training job below.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
# The role name is the last "/"-separated segment of the role ARN
# (e.g. arn:aws:iam::<account>:role/<name>). The index belongs on the result
# of split(), not on the separator.
role_name = role.split("/")[-1]
print(f"The Amazon Resource Name (ARN) of the role used for this demo is: {role}")
print(f"The name of the role used for this demo is: {role_name}")

In [None]:
from sagemaker.pytorch import PyTorch

# Configure the SageMaker PyTorch training job for the TSP model.
# - Runs ../src/run.py on a single ml.p3.16xlarge instance with the
#   SageMaker distributed data parallel library enabled.
# - `hyperparameters` are forwarded to the entry point script.
# - `metric_definitions` are regexes applied to the job logs; each one captures
#   the validation optimality gap (the text before the "%" sign) for one
#   validation dataset.
estimator = PyTorch(
    base_job_name="pytorch-smdataparallel-tsp",
    source_dir="../src",
    entry_point="run.py",
    role=role,
    framework_version="1.8.1",
    py_version="py36",
    instance_count=1,
    instance_type="ml.p3.16xlarge",
    sagemaker_session=sagemaker_session,
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    debugger_hook_config=False,
    hyperparameters={
        "problem": "tsp",
        "min_size": 50,
        "max_size": 50,
        "neighbors": 0.2,
        "knn_strat": "percentage",
        "n_epochs": 100,
        "epoch_size": 128000,
        "batch_size": 128,
        "accumulation_steps": 1,
        "train_dataset": "tsp20-50_train_concorde.txt",
        "val_datasets": "tsp20_test_concorde.txt tsp50_test_concorde.txt tsp100_test_concorde.txt",
        "val_size": 1280,
        "rollout_size": 1280,
        "model": "attention",
        "encoder": "gnn",
        "embedding_dim": 128,
        "hidden_dim": 512,
        "n_encode_layers": 3,
        "aggregation": "max",
        "normalization": "batch",
        "n_heads": 8,
        "tanh_clipping": 10.0,
        "lr_model": 0.0001,
        "lr_critic": 0.0001,
        "lr_decay": 1.0,
        "max_grad_norm": 1.0,
        "exp_beta": 0.8,
        "baseline": "rollout",
        "bl_alpha": 0.05,
        "bl_warmup_epochs": 0,
        "seed": 1234,
        "num_workers": 0,
        "log_step": 100,
    },
    # Raw strings: "\%" is not a valid escape sequence in a normal string
    # literal and triggers an invalid-escape warning. The regex text sent to
    # SageMaker is unchanged.
    metric_definitions=[
        {
            "Name": "val:gap_tsp20",
            "Regex": r"tsp20_test_concorde.txt Validation optimality gap=(.*?)\%",
        },
        {
            "Name": "val:gap_tsp50",
            "Regex": r"tsp50_test_concorde.txt Validation optimality gap=(.*?)\%",
        },
        {
            "Name": "val:gap_tsp100",
            "Regex": r"tsp100_test_concorde.txt Validation optimality gap=(.*?)\%",
        },
    ],
    max_run=1 * 24 * 60 * 60,  # job time limit: 24 hours, in seconds
)

In [None]:
# Launch the training job without blocking the notebook (wait=False).
# The "train" and "val" channels both read from the same S3 prefix.
tsp_data_uri = f"s3://{BUCKET}/data/tsp"
estimator.fit({"train": tsp_data_uri, "val": tsp_data_uri}, wait=False)