{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune GPT-2 with near-linear scaling using the Sharded Data Parallelism technique in the SageMaker Model Parallelism Library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, you learn how to fine-tune the Hugging Face Transformers GPT-2 model with the [Sharded Data Parallelism](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html) technique in [SageMaker's Model Parallelism library (SMP)](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html) with PyTorch 1.13 and the [GLUE/SST2 dataset](https://huggingface.co/datasets/glue/viewer/sst2/train) on SageMaker. \n", "\n", "The GPT-2 model was proposed by OpenAI in the paper [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). The original GPT-2 is a large transformer-based language model with 1.5 billion parameters. In this notebook, you can change the model configuration parameters to train models of different sizes. 
This notebook uses the [Hugging Face Transformers GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html) implementation with the SMP integration.\n", "\n", "Sharded data parallelism is a distributed training technique that splits the model parameters, gradients, and optimizer states across GPUs in a data parallel group. It is purpose-built for extreme-scale models and builds on Amazon's in-house [MiCS](https://arxiv.org/pdf/2205.00119.pdf) technology, which achieves near-linear scaling efficiency. For large models that cannot fit into a single GPU, we recommend first trying sharded data parallelism together with [Activation Checkpointing](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-activation-checkpointing.html) and [Activation Offloading](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-activation-offloading.html) in SMP, before turning to other techniques such as tensor parallelism or pipeline parallelism.\n", "\n", "This notebook is accompanied by the following files:\n", "\n", "- `train.py`: The entry point script that is passed to the SageMaker PyTorch estimator later in this notebook to launch the training job. This script runs end-to-end training of the GPT-2 model with SMP, applies the sharded data parallelism settings, and implements saving, loading, and fine-tuning of the model. 
You can follow the comments throughout the script to learn where the SMP APIs and code modifications are implemented.\n", "- `data_pipeline.py`: Contains data pipeline functions to prepare the training dataset.\n", "- `learining_rate.py`: Contains functions for the learning rate schedule.\n", "- `requirements.txt`: Lists the dependencies to install, including Hugging Face Transformers.\n", "- `memory_tracker.py`: Contains functions to track memory usage.\n", "- `model_config.py`: Contains functions to get model configuration information.\n", "- `sdp_utils.py`: Contains utility functions for sharded data parallelism.\n", "\n", "### Additional Resources\n", "- To learn more about the SageMaker model parallelism library, see [Model Parallel Distributed Training with SageMaker Distributed](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html).\n", "\n", "- To learn more about using the SageMaker Python SDK with PyTorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).\n", "\n", "- To learn more about launching a training job in Amazon SageMaker with your own training image, see [Use Your Own Training Algorithms](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html).\n", "\n", "- To learn more about sharded data parallelism, see [Sharded Data Parallelism](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html) or the blog post [Near-linear scaling of gigantic-model training on AWS](https://www.amazon.science/blog/near-linear-scaling-of-gigantic-model-training-on-aws).\n", "\n", "### Prerequisites\n", "You must create an S3 bucket to store the input data for training. This bucket must be located in the same AWS Region in which you launch your training job. 
To learn how to create an S3 bucket, see [Create your first S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html) in the *Amazon S3 documentation*.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Amazon SageMaker Initialization\n", "\n", "Run the following cells to upgrade the SageMaker Python SDK to the latest version, import the SageMaker modules, and retrieve information about your current SageMaker environment, such as your AWS account ID, the AWS Region, and the ARN of your Amazon SageMaker execution role.\n", "\n", "**NOTE:** This step might require a kernel restart." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install --upgrade sagemaker\n", "%pip install sagemaker-experiments" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "import os\n", "\n", "import boto3\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker.pytorch import PyTorch\n", "\n", "role = (\n", "    get_execution_role()\n", ")  # provide a pre-existing role ARN as an alternative to creating a new role\n", "print(f\"SageMaker Execution Role: {role}\")\n", "\n", "client = boto3.client(\"sts\")\n", "account = client.get_caller_identity()[\"Account\"]\n", "print(f\"AWS account: {account}\")\n", "\n", "session = boto3.session.Session()\n", "region = session.region_name\n", "print(f\"AWS region: {region}\")\n", "\n", "sm_boto_client = boto3.client(\"sagemaker\")\n", "sagemaker_session = sagemaker.session.Session(boto_session=session)\n", "\n", "# get default bucket\n", "default_bucket = sagemaker_session.default_bucket()\n", "print()\n", "print(\"Default bucket for this session: \", default_bucket)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download and prepare GLUE/SST2 data\n", "In this section, you download and prepare the GLUE/SST2 dataset, and then copy the files to S3."
] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Install the Hugging Face Transformers and Datasets libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install -q datasets transformers==4.21.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import datasets\n", "\n", "# Import only what this notebook uses; load_metric was removed in datasets 3.0.\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "import transformers\n", "import logging\n", "\n", "from transformers import (\n", "    AutoModelForCausalLM,\n", "    AutoTokenizer,\n", ")\n", "\n", "from transformers.testing_utils import CaptureLogger" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "logger = logging.getLogger(__name__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load data\n", "This section loads the [GLUE/SST2](https://huggingface.co/datasets/glue/viewer/sst2/train) dataset. If the dataset has no validation split, the code holds out the first 5% of the training split as the validation dataset."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {\n", "    \"dataset_name\": \"glue\",\n", "    \"dataset_config_name\": \"sst2\",\n", "    \"do_train\": True,\n", "    \"do_eval\": True,\n", "    \"cache_dir\": \"tmp\",\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "raw_datasets = load_dataset(\n", "    hyperparameters[\"dataset_name\"],\n", "    hyperparameters[\"dataset_config_name\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"validation\" not in raw_datasets.keys():\n", "    raw_datasets[\"validation\"] = load_dataset(\n", "        hyperparameters[\"dataset_name\"],\n", "        hyperparameters[\"dataset_config_name\"],\n", "        split=\"train[:5%]\",\n", "        cache_dir=hyperparameters[\"cache_dir\"],\n", "    )\n", "\n", "    raw_datasets[\"train\"] = load_dataset(\n", "        hyperparameters[\"dataset_name\"],\n", "        hyperparameters[\"dataset_config_name\"],\n", "        split=\"train[5%:]\",\n", "        cache_dir=hyperparameters[\"cache_dir\"],\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load tokenizer\n", "Nearly every NLP task begins with a tokenizer, which converts raw text into tokens (integer IDs) that the model can process.\n", "The following cell loads the GPT-2 tokenizer using [AutoTokenizer.from_pretrained()](https://huggingface.co/docs/transformers/v4.19.4/en/autoclass_tutorial#autotokenizer)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer_kwargs = {\n", "    \"cache_dir\": hyperparameters[\"cache_dir\"],\n", "}\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\", **tokenizer_kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocess data\n", "\n", "The following two cells define a function that runs the tokenizer and a function that concatenates the tokenized texts and splits them into chunks of `block_size` tokens, and then apply both functions to the dataset."
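, "\n", "\n", "The concatenate-and-chunk logic of `group_texts` can be illustrated with plain Python lists. The following standalone sketch is not part of the training workflow; the helper name `chunk_tokens`, the toy token IDs, and the block size of 4 are hypothetical values chosen only to make the chunking visible.\n", "\n", "```python\n", "# Standalone sketch of the concatenate-and-chunk step in group_texts.\n", "# The helper name, toy token IDs, and block size are illustrative only.\n", "def chunk_tokens(examples, block_size):\n", "    # Concatenate each column's lists into one long list.\n", "    concatenated = {k: sum(examples[k], []) for k in examples}\n", "    total_length = len(concatenated[\"input_ids\"])\n", "    # Drop the remainder that does not fill a whole block, as group_texts does.\n", "    if total_length >= block_size:\n", "        total_length = (total_length // block_size) * block_size\n", "    result = {\n", "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", "        for k, t in concatenated.items()\n", "    }\n", "    # Labels are a copy of the inputs; Hugging Face causal LM heads shift them internally.\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result\n", "\n", "\n", "batch = {\"input_ids\": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}\n", "print(chunk_tokens(batch, block_size=4))\n", "# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}\n", "```\n", "\n", "Token `9` is dropped because it does not fill a complete block of four tokens."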
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tokenize_function(examples):\n", "    tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n", "\n", "    with CaptureLogger(tok_logger) as cl:\n", "        output = tokenizer(examples[text_column_name])\n", "    # clm input could be much much longer than block_size\n", "    if \"Token indices sequence length is longer than the\" in cl.out:\n", "        tok_logger.warning(\n", "            \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\"\n", "        )\n", "    return output\n", "\n", "\n", "# Main data processing function that concatenates all texts from the dataset and generates chunks of block_size.\n", "def group_texts(examples):\n", "    # Concatenate all texts.\n", "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n", "    # Drop the small remainder. If the model supported padding, you could pad instead of dropping;\n", "    # customize this part to your needs.\n", "    if total_length >= block_size:\n", "        total_length = (total_length // block_size) * block_size\n", "    # Split into chunks of block_size.\n", "    result = {\n", "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", "        for k, t in concatenated_examples.items()\n", "    }\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "column_names = raw_datasets[\"train\"].column_names\n", "text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n", "\n", "# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n", "tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n", "\n", 
"tokenized_datasets = raw_datasets.map(\n", " tokenize_function,\n", " batched=True,\n", " num_proc=1,\n", " remove_columns=column_names,\n", " desc=\"Running tokenizer on dataset\",\n", ")\n", "\n", "\n", "block_size = tokenizer.model_max_length\n", "if block_size > 1024:\n", " logger.warning(\n", " f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n", " \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n", " )\n", " block_size = 1024\n", "else:\n", " if block_size > tokenizer.model_max_length:\n", " logger.warning(\n", " f\"The block_size passed ({block_size}) is larger than the maximum length for the model\"\n", " f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n", " )\n", " block_size = min(block_size, tokenizer.model_max_length)\n", "\n", "lm_datasets = tokenized_datasets.map(\n", " group_texts,\n", " batched=True,\n", " # num_proc=args.preprocessing_num_workers,\n", " desc=f\"Grouping texts in chunks of {block_size}\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set additional hyperparameters and S3 paths for mapping the train and validation datasets properly depending on the phase (training or validation) of the training job in each epoch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if hyperparameters[\"do_train\"]:\n", " if \"train\" not in tokenized_datasets:\n", " raise ValueError(\"--do_train requires a train dataset\")\n", " train_dataset = lm_datasets[\"train\"]\n", "\n", "\n", "if hyperparameters[\"do_eval\"]:\n", " if \"validation\" not in tokenized_datasets:\n", " raise ValueError(\"--do_eval requires a validation dataset\")\n", " eval_dataset = lm_datasets[\"validation\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_dataset_location = None\n", "validation_dataset_location = None\n", "\n", "\n", "if hyperparameters[\"do_train\"]:\n", " train_dataset.to_json(\"./training.json\")\n", " training_dataset_location = \"s3://{}/dataset/train/\".format(default_bucket)\n", "\n", "if hyperparameters[\"do_eval\"]:\n", " eval_dataset.to_json(\"./validation.json\")\n", " validation_dataset_location = \"s3://{}/dataset/validation/\".format(default_bucket)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if training_dataset_location is not None:\n", " command = \"aws s3 cp ./training.json {}\".format(training_dataset_location)\n", " os.system(command)\n", "\n", "if validation_dataset_location is not None:\n", " command = \"aws s3 cp ./validation.json {}\".format(validation_dataset_location)\n", " os.system(command)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if hyperparameters[\"do_train\"]:\n", " command = \"rm ./training.json\"\n", " os.system(command)\n", "\n", "if hyperparameters[\"do_eval\"]:\n", " command = \"rm ./validation.json\"\n", " os.system(command)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%store training_dataset_location\n", "%store validation_dataset_location" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "%store" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Specify Amazon S3 bucket paths" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here you need to specify the paths for training data to be used by your job. The bucket used must be in the same region as where training will run. In the cells above you downloaded the GLUE/SST2 training and validation split datasets and uploaded the json files in an S3 bucket in your account. This example will train on those json files.\n", "\n", "After you successfully run this example tensor parallel training job, you can modify the S3 bucket to where your own dataset is stored." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%store -r training_dataset_location\n", "%store -r validation_dataset_location\n", "\n", "# if you're bringing your own data, uncomment the following lines and specify the locations there\n", "# training_dataset_location = YOUR_S3_BUCKET/training\n", "# validation_dataset_location = YOUR_S3_BUCKET/validation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3_train_bucket = training_dataset_location\n", "s3_test_bucket = validation_dataset_location" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following S3 bucket will store the output artifacts of the training job. You can modify this as needed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3_output_bucket = f\"s3://sagemaker-{region}-{account}/smp-tensorparallel-outputdir/\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Data Channels for SageMaker Training Using Amazon S3\n", "\n", "In this step, define SageMaker training data channels to the S3 buckets. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set use_fsx to False by default\n", "# Set below var to True if you want to use fsx (see next cell)\n", "use_fsx = False\n", "if not use_fsx:\n", " if s3_train_bucket != None:\n", " train = sagemaker.inputs.TrainingInput(\n", " s3_train_bucket, distribution=\"FullyReplicated\", s3_data_type=\"S3Prefix\"\n", " )\n", " data_channels = {\"train\": train}\n", " else:\n", " data_channels = {\"train\": mock_data}\n", " if s3_test_bucket != None:\n", " test = sagemaker.inputs.TrainingInput(\n", " s3_test_bucket, distribution=\"FullyReplicated\", s3_data_type=\"S3Prefix\"\n", " )\n", " data_channels[\"test\"] = test\n", " else:\n", " data_channels[\"test\"] = mock_data\n", " print(data_channels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## (Optional) Set Up and Use Amazon FSx for Data Channels and Checkpoints\n", "\n", "While the previous option of using Amazon S3 is easier to setup, using an FSx can be beneficial for performance when dealing with large input sizes and large model sizes. If you are using models above 13B, checkpointing should be done using FSx. \n", "\n", "Please see the instructions from [Distributed Training of Mask-RCNN in Amazon SageMaker Using FSx](https://github.com/aws/amazon-sagemaker-examples/blob/master/advanced_functionality/distributed_tensorflow_mask_rcnn/mask-rcnn-scriptmode-fsx.ipynb) to create an FSx Lustre file system and import the dataset from the S3 bucket to your FSx file system. Note that the FSx file system must be created in a private subnet with internet gateway to ensure that training job has access to the internet. For general guidance on setting an FSx Lustre file system as data input channel, see [Configure Data Input Channel to Use Amazon FSx for Lustre](https://docs.aws.amazon.com/sagemaker/latest/dg/model-access-training-data.html#model-access-training-data-fsx)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Instructions obtained from:\n", "# https://github.com/aws/amazon-sagemaker-examples/blob/master/advanced_functionality/distributed_tensorflow_mask_rcnn/mask-rcnn-scriptmode-fsx.ipynb\n", "\n", "if use_fsx:\n", "    from sagemaker.inputs import FileSystemInput\n", "\n", "    # Specify the FSx for Lustre file system ID.\n", "    file_system_id = \"\"\n", "\n", "    # Specify the security group and subnet used by the FSx file system. These are passed\n", "    # to the SageMaker estimator so that the training job uses them as well.\n", "    fsx_security_group_id = \"\"\n", "    fsx_subnet = \"\"\n", "\n", "    # Specify the directory path for input data on the file system.\n", "    # You need to provide a normalized, absolute path below.\n", "    # You can choose the mount name when creating the FSx file system, or it is generated automatically.\n", "    # You can find this mount_name on the FSx page in the console.\n", "    # Example of an FSx-generated mount_name: \"3x5lhbmv\"\n", "    base_path = \"\"\n", "\n", "    # Specify your file system type.\n", "    file_system_type = \"FSxLustre\"\n", "\n", "    train = FileSystemInput(\n", "        file_system_id=file_system_id,\n", "        file_system_type=file_system_type,\n", "        directory_path=base_path,\n", "        file_system_access_mode=\"rw\",\n", "    )\n", "\n", "    data_channels = {\"train\": train, \"test\": train}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set hyperparameters, metric definitions, and MPI options\n", "The following `hyperparameters` dictionary passes arguments to the training script (`train.py`) and sets the model parallel configuration when you create the training job.\n", "\n", "You can also add custom MPI options through `mpioptions`. By default, `--mca btl_vader_single_copy_mechanism none` is set to suppress unnecessary logs.\n", "\n", "Next, add a base metric definition to enable metric upload in SageMaker. 
You can add further metric definitions as needed.\n", "\n", "Note that the `sharded_data_parallel_degree` parameter is added to the `hyperparameters` dictionary. It is parsed and used later when configuring the SageMaker PyTorch estimator to activate sharded data parallelism.\n", "\n", "Also note the `fine_tune` parameter, which activates the fine-tuning code path in the `train.py` script." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {\n", "    \"max_steps\": 100,\n", "    \"seed\": 12345,\n", "    \"fp16\": 0,\n", "    \"bf16\": 1,\n", "    \"lr\": 2.0e-4,\n", "    \"lr_decay_iters\": 125000,\n", "    \"min_lr\": 0.00001,\n", "    \"lr-decay-style\": \"linear\",\n", "    \"warmup\": 0.01,\n", "    \"num_kept_checkpoints\": 5,\n", "    \"checkpoint_freq\": 200,\n", "    \"logging_freq\": 1,\n", "    \"save_final_full_model\": 0,\n", "    \"delayed_param\": 1,\n", "    \"use_distributed_transformer\": 1,\n", "    \"offload_activations\": 0,\n", "    \"gradient_accumulation\": 1,\n", "    \"validation_freq\": 200,\n", "    \"train_batch_size\": 10,\n", "    \"val_batch_size\": 4,\n", "    \"flash_attention\": 1,\n", "    \"zipped_data\": 0,\n", "    \"epochs\": 100,\n", "    # parameters for activating the fine tuning mode\n", "    \"fine_tune\": 1,\n", "    \"model_name\": \"gpt2-xl\",\n", "    # parameters for sharded data parallelism\n", "    \"sharded_data_parallel_degree\": 8,\n", "}\n", "\n", "if use_fsx:\n", "    # make sure to update paths for training-dir and test-dir based on the paths of datasets in fsx\n", "    # If you want to resume training, set checkpoint-dir to the same path as a previous job.\n", "    SM_TRAIN_DIR = \"/opt/ml/input/data/train\"\n", "    hyperparameters[\"checkpoint-dir\"] = f\"{SM_TRAIN_DIR}/checkpointdir-job2\"\n", "    hyperparameters[\"model-dir\"] = f\"{SM_TRAIN_DIR}/modeldir-job2\"\n", "    hyperparameters[\"training-dir\"] = f\"{SM_TRAIN_DIR}/datasets/pytorch_gpt2/train_synthetic\"\n", "    hyperparameters[\"test-dir\"] = 
f\"{SM_TRAIN_DIR}/datasets/pytorch_gpt2/val_synthetic\"\n", "\n", "# The checkpoint path (hyperparameters['checkpoint-dir'] or checkpoint_s3_uri) is not unique per job.\n", "# You need to modify as needed for different runs.\n", "# If same path is used for unrelated runs, this may increase time when downloading unnecessary checkpoints,\n", "# and cause conflicts when loading checkpoints.\n", "\n", "mpioptions = \"-x NCCL_DEBUG=WARN -x SMDEBUG_LOG_LEVEL=ERROR \"\n", "mpioptions += (\n", " \"-x SMP_DISABLE_D2D=1 -x SMP_D2D_GPU_BUFFER_SIZE_BYTES=1 -x SMP_NCCL_THROTTLE_LIMIT=1 \"\n", ")\n", "mpioptions += \"-x FI_EFA_USE_DEVICE_RDMA=1 -x FI_PROVIDER=efa -x RDMAV_FORK_SAFE=1\"\n", "\n", "metric_definitions = [\n", " {\"Name\": \"base_metric\", \"Regex\": \"<><><><><><>\"}\n", "] # Add your custom metric definitions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the model configuration. Specify one from `gpt2-30b`, `gpt2-xl` and `gpt2-small`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_config = \"gpt2-xl\"\n", "\n", "if model_config == \"gpt2-30b\":\n", " model_params = {\n", " \"max_context_width\": 2048,\n", " \"hidden_width\": 7168,\n", " \"num_layers\": 48,\n", " \"num_heads\": 64,\n", " }\n", "\n", "elif model_config == \"gpt2-xl\":\n", " # 1.5B\n", " model_params = {\n", " \"max_context_width\": 2048,\n", " \"hidden_width\": 1536,\n", " \"num_layers\": 48,\n", " \"num_heads\": 24,\n", " }\n", "elif model_config == \"gpt2-small\":\n", " model_params = {\n", " \"max_context_width\": 2048,\n", " \"hidden_width\": 768,\n", " \"num_layers\": 12,\n", " \"num_heads\": 12,\n", " }\n", "else:\n", " raise RuntimeError(\"Unknown model config\")\n", "\n", "for k, v in model_params.items():\n", " hyperparameters[k] = v" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Specify essential parameters for a SageMaker Training job\n", "\n", "Next, you use the [`SageMaker Estimator 
class`](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) to define a SageMaker training job. The following parameters set the training job name, the number and type of EC2 instances, and the size of the storage volume attached to each instance:\n", "\n", "* `instance_count`\n", "* `instance_type`\n", "* `volume_size`\n", "* `base_job_name`\n", "\n", "### Update the type and the number of EC2 instances to use\n", "\n", "The values you specify for the `instance_type` and `instance_count` parameters determine the total number of GPUs (the world size):\n", "\n", "$$ \\text{world size} = \\text{(number of GPUs per instance)} \\times \\text{(number of instances)} $$\n", "\n", "- For GPT-2 with 30 billion parameters, you need at least 16 `ml.p4d.24xlarge` instances.\n", "- For GPT-2 xl, use at least one `ml.p4d.24xlarge` instance.\n", "- For GPT-2 small, use at least one `ml.p3.16xlarge` instance." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "instance_type = \"ml.p4d.24xlarge\"\n", "instance_count = 1\n", "\n", "# set to the number of GPUs on that instance\n", "processes_per_host = 8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To look up the number of GPUs of different instance types, see [Amazon EC2 Instance Types](https://aws.amazon.com/ec2/instance-types/) and use the **Accelerated Computing** section to see general purpose GPU instances. Note that each EC2 instance type, such as `p4d.24xlarge`, has a corresponding SageMaker instance type, such as `ml.p4d.24xlarge`.\n", "For SageMaker-supported `ml` instances and pricing information, see [Amazon SageMaker Pricing](https://aws.amazon.com/sagemaker/pricing/). 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Specify a base job name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "machine_str = instance_type.split(\".\")[1] + instance_type.split(\".\")[2][:3]\n", "sharding_degree = hyperparameters[\"sharded_data_parallel_degree\"]\n", "base_job_name = (\n", " f'smp-{model_config}-{machine_str}-sdp{sharding_degree}-bs{hyperparameters[\"train_batch_size\"]}'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if not use_fsx:\n", " # If you want to resume training, set checkpoint_s3_uri to the same path as a previous job.\n", " # Previous checkpoint to load must have same model config.\n", " checkpoint_bucket = f\"s3://sagemaker-{region}-{account}/\"\n", " checkpoint_s3_uri = (\n", " f\"{checkpoint_bucket}/experiments/gpt_synthetic_simpletrainer_checkpoints/{base_job_name}/\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"base_job_name: {base_job_name} checkpoint_s3_uri: {checkpoint_s3_uri}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a SageMaker PyTorch estimator\n", "\n", "The following cell constructs a PyTorch estimator using the parameters defined above. To see how the SageMaker APIs and functions are applied to the script, see the `train.py` file." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kwargs = {}\n", "if use_fsx:\n", "    # Use the security group and subnet that were used to create the FSx file system\n", "    kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", "    kwargs[\"subnets\"] = [fsx_subnet]\n", "\n", "smp_estimator = PyTorch(\n", "    entry_point=\"train.py\",\n", "    source_dir=os.getcwd(),\n", "    role=role,\n", "    instance_type=instance_type,\n", "    instance_count=instance_count,\n", "    sagemaker_session=sagemaker_session,\n", "    distribution={\n", "        \"mpi\": {\n", "            \"enabled\": True,\n", "            \"processes_per_host\": processes_per_host,\n", "            \"custom_mpi_options\": mpioptions,\n", "        },\n", "        \"smdistributed\": {\n", "            \"modelparallel\": {\n", "                \"enabled\": True,\n", "                \"parameters\": {\n", "                    \"ddp\": True,\n", "                    \"skip_tracing\": True,\n", "                    \"delayed_parameter_initialization\": hyperparameters[\"delayed_param\"] > 0,\n", "                    \"offload_activations\": hyperparameters[\"offload_activations\"] > 0,\n", "                    \"sharded_data_parallel_degree\": hyperparameters[\"sharded_data_parallel_degree\"],\n", "                    \"fp16\": hyperparameters[\"fp16\"] > 0,\n", "                    \"bf16\": hyperparameters[\"bf16\"] > 0,\n", "                    # partitions is a required param in the current SM SDK, so it must be passed.\n", "                    \"partitions\": 1,\n", "                },\n", "            }\n", "        },\n", "    },\n", "    framework_version=\"1.13\",\n", "    py_version=\"py39\",\n", "    output_path=s3_output_bucket,\n", "    checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n", "    checkpoint_local_path=hyperparameters[\"checkpoint-dir\"] if use_fsx else None,\n", "    metric_definitions=metric_definitions,\n", "    hyperparameters=hyperparameters,\n", "    debugger_hook_config=False,\n", "    disable_profiler=True,\n", "    base_job_name=base_job_name,\n", "    **kwargs,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, run the `smp_estimator.fit()` method to launch the SageMaker training job that fine-tunes the GPT-2 model with sharded data 
parallelism." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "smp_estimator.fit(inputs=data_channels, logs=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Accessing the Training Logs\n", "\n", "You can access the training logs from [Amazon CloudWatch](https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/WhatIsCloudWatch.html). Make sure to look at the logs of **algo-1**, the main node, whose output stream contains the training job logs.\n", "\n", "You can use CloudWatch to track SageMaker GPU and memory utilization during training and inference. To view the metrics and logs that SageMaker writes to CloudWatch, see [SageMaker Jobs and Endpoint Metrics](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html#cloudwatch-metrics-jobs) in the Amazon SageMaker Developer Guide.\n", "\n", "If you are a new user of CloudWatch, see [Getting Started with Amazon CloudWatch](https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/GettingStarted.html). \n", "\n", "For additional information on monitoring and analyzing Amazon SageMaker training jobs, see [Monitor and Analyze Training Jobs Using Metrics](https://docs.aws.amazon.com/sagemaker/latest/dg/training-metrics.html).\n", "\n", "## Deploying a Trained Model for Inference\n", "\n", "In many cases, a trained model can be deployed on a single device for inference, because inference needs much less memory than training: it does not keep gradients or optimizer states. 
You can use the SMP API to create a single, unified model after training: the [smp.DistributedModel.save_model()](https://sagemaker.readthedocs.io/en/stable/api/training/smp_versions/latest/smd_model_parallel_tensorflow.html#smp.DistributedModel.save_model) method for TensorFlow, and the [smp.save()](https://sagemaker.readthedocs.io/en/stable/api/training/smp_versions/latest/smd_model_parallel_pytorch.html#apis-for-saving-and-loading) function for PyTorch.\n", "\n", "After you build and train your models, you can deploy them to get predictions in one of two ways:\n", "\n", "* To set up a persistent endpoint to get predictions from your models, use SageMaker hosting services. For an overview of deploying a single model or multiple models with SageMaker hosting services, see [Deploy a Model on SageMaker Hosting Services](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-deployment.html#how-it-works-hosting).\n", "* To get predictions for an entire dataset, use SageMaker batch transform. For an overview of deploying a model with SageMaker Batch Transform, see [Get Inferences for an Entire Dataset with Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-batch.html).\n", "\n", "To learn more about deploying models for inference using SageMaker, see [Deploy Models for Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html). \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "![This badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n", "\n", "![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/training|distributed_training|pytorch|model_parallel|gpt2|smp-train-gpt-simple-sharded-data-parallel.ipynb)\n" ] } ], "metadata": { "hide_input": false, "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_pytorch_p310", "language": "python", "name": "conda_pytorch_p310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 4 }