{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a83987d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext nb_black\n",
    "\n",
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker.pytorch import PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ecec369",
   "metadata": {},
   "source": [
    "# 1. Prepare Training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93acdeda",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Download the dataset archive only when it is not already present.\n",
    "# The archive is kept at ../tsp-data.tar.gz (the path the extraction cell reads),\n",
    "# so the existence check must look there; otherwise it is re-downloaded every run.\n",
    "if [ -f ../tsp-data.tar.gz ]; then\n",
    "    echo \"File tsp-data.tar.gz exists.\"\n",
    "else\n",
    "    echo \"File tsp-data.tar.gz does not exist.\"\n",
    "    # Quote the URL so the shell never treats '?' as a glob character.\n",
    "    gdown \"https://drive.google.com/uc?id=152mpCze-v4d0m9kdsCeVkLdHFkjeDeF5\"\n",
    "    mv tsp-data.tar.gz ../\n",
    "fi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c4a3a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Extract the archive into ../ unless the ../data folder already exists.\n",
    "if [ -d ../data ]; then\n",
    "    echo \"Folder data exists.\"\n",
    "else\n",
    "    echo \"Folder data does not exist.\"\n",
    "    tar -xvzf ../tsp-data.tar.gz -C ../\n",
    "fi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92e2bb39",
   "metadata": {},
   "outputs": [],
   "source": [
    "session = sagemaker.Session()\n",
    "BUCKET = session.default_bucket()  # Set a default S3 bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a821309d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the validation files to s3://BUCKET/data/tsp/, the prefix the training job reads.\n",
    "# NOTE(review): only the three *_test_concorde.txt files are uploaded, while the\n",
    "# `train_dataset` hyperparameter below names tsp20-50_train_concorde.txt.\n",
    "# Confirm whether run.py needs that file in the train channel.\n",
    "s3 = boto3.resource(\"s3\")\n",
    "for file in [\n",
    "    \"tsp20_test_concorde.txt\",\n",
    "    \"tsp50_test_concorde.txt\",\n",
    "    \"tsp100_test_concorde.txt\",\n",
    "]:\n",
    "    s3.meta.client.upload_file(f\"../data/tsp/{file}\", BUCKET, f\"data/tsp/{file}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90012792",
   "metadata": {},
   "source": [
    "# 2. Distributed Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b35e5ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the session created above instead of constructing a second, identical one.\n",
    "sagemaker_session = session\n",
    "role = sagemaker.get_execution_role()\n",
    "# The role name is the last \"/\"-separated component of the role ARN.\n",
    "role_name = role.split(\"/\")[-1]\n",
    "print(f\"The Amazon Resource Name (ARN) of the role used for this demo is: {role}\")\n",
    "print(f\"The name of the role used for this demo is: {role_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63fb353",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the training job; nothing is launched until `estimator.fit` is called.\n",
    "estimator = PyTorch(\n",
    "    base_job_name=\"pytorch-smdataparallel-tsp\",\n",
    "    source_dir=\"../src\",\n",
    "    entry_point=\"run.py\",\n",
    "    role=role,\n",
    "    framework_version=\"1.8.1\",\n",
    "    py_version=\"py36\",\n",
    "    instance_count=1,\n",
    "    instance_type=\"ml.p3.16xlarge\",\n",
    "    sagemaker_session=sagemaker_session,\n",
    "    distribution={\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}},\n",
    "    debugger_hook_config=False,\n",
    "    hyperparameters={\n",
    "        \"problem\": \"tsp\",\n",
    "        \"min_size\": 50,\n",
    "        \"max_size\": 50,\n",
    "        \"neighbors\": 0.2,\n",
    "        \"knn_strat\": \"percentage\",\n",
    "        \"n_epochs\": 100,\n",
    "        \"epoch_size\": 128000,\n",
    "        \"batch_size\": 128,\n",
    "        \"accumulation_steps\": 1,\n",
    "        \"train_dataset\": \"tsp20-50_train_concorde.txt\",\n",
    "        \"val_datasets\": \"tsp20_test_concorde.txt tsp50_test_concorde.txt tsp100_test_concorde.txt\",\n",
    "        \"val_size\": 1280,\n",
    "        \"rollout_size\": 1280,\n",
    "        \"model\": \"attention\",\n",
    "        \"encoder\": \"gnn\",\n",
    "        \"embedding_dim\": 128,\n",
    "        \"hidden_dim\": 512,\n",
    "        \"n_encode_layers\": 3,\n",
    "        \"aggregation\": \"max\",\n",
    "        \"normalization\": \"batch\",\n",
    "        \"n_heads\": 8,\n",
    "        \"tanh_clipping\": 10.0,\n",
    "        \"lr_model\": 0.0001,\n",
    "        \"lr_critic\": 0.0001,\n",
    "        \"lr_decay\": 1.0,\n",
    "        \"max_grad_norm\": 1.0,\n",
    "        \"exp_beta\": 0.8,\n",
    "        \"baseline\": \"rollout\",\n",
    "        \"bl_alpha\": 0.05,\n",
    "        \"bl_warmup_epochs\": 0,\n",
    "        \"seed\": 1234,\n",
    "        \"num_workers\": 0,\n",
    "        \"log_step\": 100,\n",
    "    },\n",
    "    # Raw strings keep the regex text unchanged while avoiding the invalid \"\\%\"\n",
    "    # escape sequence that a normal string literal would contain.\n",
    "    metric_definitions=[\n",
    "        {\n",
    "            \"Name\": \"val:gap_tsp20\",\n",
    "            \"Regex\": r\"tsp20_test_concorde.txt Validation optimality gap=(.*?)\\%\",\n",
    "        },\n",
    "        {\n",
    "            \"Name\": \"val:gap_tsp50\",\n",
    "            \"Regex\": r\"tsp50_test_concorde.txt Validation optimality gap=(.*?)\\%\",\n",
    "        },\n",
    "        {\n",
    "            \"Name\": \"val:gap_tsp100\",\n",
    "            \"Regex\": r\"tsp100_test_concorde.txt Validation optimality gap=(.*?)\\%\",\n",
    "        },\n",
    "    ],\n",
    "    max_run=1 * 24 * 60 * 60,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aa9ef66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both channels point at the same S3 prefix; wait=False returns without blocking on the job.\n",
    "estimator.fit(\n",
    "    {\"train\": f\"s3://{BUCKET}/data/tsp\", \"val\": f\"s3://{BUCKET}/data/tsp\"}, wait=False\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_tsp_deep_rl",
   "language": "python",
   "name": "conda_tsp_deep_rl"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}