{ "cells": [ { "cell_type": "markdown", "id": "86e62a9e-44cc-43ca-acd3-3f3318036c5f", "metadata": {}, "source": [ "# Train Stable Diffusion on SageMaker distributed training\n", "For this notebook to work properly, you'll need to make sure that FSx for Lustre has actually inherited all of your available data. To do this, open the Lustre console view, select the association id that points to your preferred S3 path, then click \"actions,\" \"create an import task.\" This will start an action in Lustre to import all of the data in your S3 path onto Lustre.\n", "\n", "Specifically this script is expecting multiple folders named \"part-{}\"." ] }, { "cell_type": "markdown", "id": "9bb21f3c-dbb5-48e7-82f1-af7957189795", "metadata": {}, "source": [ "### Step 1. Point to FSx for Lustre" ] }, { "cell_type": "code", "execution_count": 2, "id": "b2ecf912-9653-4773-ac0f-9582ea9dde5e", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.inputs import FileSystemInput\n", "\n", "# Specify FSx Lustre file system id.\n", "file_system_id = \"fs-0a83907c9c9c7b8f0\"\n", "\n", "# Specify the SG and subnet used by the FSX, these are passed to SM Estimator so jobs use this as well\n", "fsx_security_group_id = \"sg-ac4f1cb5\"\n", "fsx_subnet = \"subnet-be054be1\"\n", "\n", "# Specify directory path for input data on the file system.\n", "# You need to provide normalized and absolute path below.\n", "# Your mount name can be provided by you when creating fsx, or generated automatically.\n", "# You can find this mount_name on the FSX page in console.\n", "# Example of fsx generated mount_name: \"3x5lhbmv\"\n", "base_path = \"/yflftbev\"\n", "\n", "# Specify your file system type.\n", "file_system_type = \"FSxLustre\"\n", "\n", "train = FileSystemInput(\n", " file_system_id=file_system_id,\n", " file_system_type=file_system_type,\n", " directory_path=base_path,\n", " file_system_access_mode=\"rw\",\n", ")\n", "\n", "data_channels = {\"train\": train}" ] }, { "cell_type": "code", "execution_count": 3, "id": "31bb9181-3117-4dfe-b398-6dfe2e6648d7", "metadata": { "tags": [] }, "outputs": [], "source": [ "kwargs = {}\n", "# Use the security group and subnet that was used to create the fsx filesystem\n", "kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", "kwargs[\"subnets\"] = [fsx_subnet]" ] }, { "cell_type": "markdown", "id": "0642b0b5-3509-40f7-bbb1-35194a7749b4", "metadata": {}, "source": [ "### Step 2. Process data and build a json index\n", "In my implementation of this, I actually built my own data loader function that used a custom json lines file. This saved a lot of time in loading the data, because rather than needing to `ls` all of my files, I simply had them predefined. You might think that's not a big deal, but once you're looking at more than a few million image/text pairs, it adds up!\n", "\n", "Details on my full case study are [available here](https://medium.com/@emilywebber/how-i-trained-10tb-for-stable-diffusion-on-sagemaker-39dcea49ce32)." 
] }, { "cell_type": "code", "execution_count": 4, "id": "e7ee6d84-316a-42c0-aef3-3ca54248f5ad", "metadata": { "tags": [] }, "outputs": [], "source": [ "!mkdir stable_scripts" ] }, { "cell_type": "code", "execution_count": 34, "id": "d08927d8-106d-4b63-b0da-2cd8b8f7e40c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting stable_scripts/process_data.py\n" ] } ], "source": [ "%%writefile stable_scripts/process_data.py\n", "\n", "import argparse\n", "import math\n", "import os\n", "import random\n", "from pathlib import Path\n", "from typing import Optional\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "import torch.utils.checkpoint\n", "\n", "from os import listdir\n", "import os\n", "from skimage import io\n", "\n", "import PIL\n", "from accelerate import Accelerator\n", "from accelerate.logging import get_logger\n", "from accelerate.utils import set_seed\n", "from datasets import load_dataset\n", "from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n", "from diffusers.optimization import get_scheduler\n", "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n", "from huggingface_hub import HfFolder, Repository, whoami\n", "from torchvision import transforms\n", "from torchvision.io import ImageReadMode, read_image\n", "from tqdm.auto import tqdm\n", "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n", "import requests\n", "from PIL import Image\n", "# from mpi4py import MPI\n", "\n", "from datasets import load_dataset \n", "from datasets import Dataset\n", "from datasets import DatasetDict\n", "\n", "import glob\n", "import multiprocessing as mp\n", "from multiprocessing import Pool\n", "import pandas as pd\n", "import json\n", "\n", "def parse_args():\n", " \n", " parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n", "\n", " parser.add_argument(\"--train_data_dir\", type=str, default=os.environ['SM_HP_TRAIN_DATA_DIR'], help=\"A folder containing the training data.\")\n", " \n", " parser.add_argument(\"--function\", type=str, default=os.environ['SM_HP_FUNCTION'], help=\"A generic argument to determine function of this script. 
Could be unzip and/or build pointer\")\n", " \n", " parser.add_argument(\"--model_dir\", type=str, default=os.environ['SM_MODEL_DIR'], help=\"SM training path for model, will copy to S3 after job completes\")\n", " \n", " parser.add_argument(\"--index_name\", type=str, default=os.environ['SM_HP_INDEX_NAME'], help=\"To point to the name of the index file\")\n", " \n", " args = parser.parse_args()\n", " \n", " return args\n", "\n", "def read_caption(path_to_image):\n", " '''\n", " Takes a full path and full object number, returns the string content\n", " '''\n", " path_to_text = path_to_image.replace('jpg', 'txt')\n", " \n", " with open(path_to_text) as f:\n", " data = f.readlines()\n", " return data\n", " \n", "def unzip_part(full_part_path):\n", " '''\n", " Adds all images for one part to the dataset dictionary list\n", " ''' \n", " \n", " print ('Working on part num: {}'.format(full_part_path))\n", " \n", " img_list = glob.glob(\"{}/*.jpg\".format(full_part_path))\n", "\n", " print ('This part now has {} images!'.format(len(img_list)))\n", "\n", " unzip = True\n", "\n", " if unzip:\n", " \n", " # look for all tar balls in this part path\n", " tar_balls = glob.glob(\"{}/*.tar\".format(full_part_path))\n", " \n", " # this would be the place to try and add multiprocessing \n", " for tball in tar_balls:\n", "\n", " # unzip the folder there, checks per file if already unzipped\n", " cmd = 'tar -xf {} --skip-old-files --directory {}'.format(tball, full_part_path)\n", " os.system(cmd)\n", " \n", "def write_index(full_part_path):\n", " '''\n", " Takes one full part, loops through it, grabs each image/text pair, writes them to a json lines file.\n", " '''\n", "\n", " print ('Working on part num: {}'.format(full_part_path))\n", "\n", " img_list = glob.glob(\"{}/*.jpg\".format(full_part_path))\n", "\n", " print ('This part now has {} images!'.format(len(img_list)))\n", "\n", " index_path = full_part_path.split('part')[0] + 'data_index.jsonl'\n", "\n", " print ('Writing index to {}'.format(index_path))\n", "\n", " with open(index_path, 'a') as fp:\n", "\n", " for path_to_image in img_list:\n", "\n", " try:\n", " caption = read_caption(path_to_image)\n", " pair = {\"image\":path_to_image, \"caption\": caption[0]} \n", " json.dump(pair, fp)\n", " fp.write('\\n')\n", " \n", " except:\n", " continue\n", " \n", "if __name__ == \"__main__\": \n", "\n", " args = parse_args()\n", " \n", " print ('Train data dir is here: {}'.format(args.train_data_dir))\n", " \n", " part_list = glob.glob(\"{}/part-*\".format(args.train_data_dir)) \n", " \n", " print ('Found {} parts to work on, starting multiprocessing pool'.format(len(part_list)))\n", " \n", " cpus = mp.cpu_count()\n", " \n", " with Pool(cpus) as p:\n", " \n", " if 'unzip' in args.function:\n", " p.map(unzip_part, part_list)\n", "\n", " if 'index' in args.function:\n", " p.map(write_index, part_list)\n", " \n", " cmd = 'cp {}/data_index.jsonl {}'.format(args.train_data_dir, args.model_dir)\n", "\n", " os.system(cmd)" ] }, { "cell_type": "markdown", "id": "d30c80b0-1c56-47d3-8299-de7cc7e3c188", "metadata": {}, "source": [ "#### Now let's run that on SageMaker training" ] }, { "cell_type": "code", "execution_count": 35, "id": "395d5ecf-1528-4337-8aae-a200863b693f", "metadata": { "tags": [] }, "outputs": [], "source": [ "version = 'v1'\n", "\n", "# points to an image I've made and am hosting for you to use\n", "image_uri = '220691188711.dkr.ecr.us-east-1.amazonaws.com/stable-diffusion:{}'.format(version )" ] }, { "cell_type": "code", "execution_count": 39, 
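"id": "a1b2c3d4-0000-4000-8000-000000000001", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Optional: the image URI above points to a shared container image. If you've built and pushed\n", "# the container to your own ECR repository, you can derive the URI for your own account and\n", "# region instead. This is only a sketch: it assumes you pushed to a repository named\n", "# 'stable-diffusion', so adjust the name (and tag) to whatever you actually used.\n", "import boto3\n", "\n", "region = boto3.session.Session().region_name\n", "account = boto3.client(\"sts\").get_caller_identity()[\"Account\"]\n", "# image_uri = f\"{account}.dkr.ecr.{region}.amazonaws.com/stable-diffusion:{version}\"" ] }, { "cell_type": "code", "execution_count": null,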
"id": "f22f8449-b1fc-4b57-92ef-5d245fd2282e", "metadata": { "tags": [] }, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker.pytorch import PyTorch\n", "\n", "sess = sagemaker.Session()\n", "role = sagemaker.get_execution_role()\n", "\n", "bucket = sess.default_bucket()\n", "\n", "hyperparameters = {'train_data_dir':'/opt/ml/input/data/train/fsx-data',\n", " 'function':'unzip,index', \n", " 'index_name': 'data_index.jsonl'}\n", "\n", "estimator = PyTorch(\n", " entry_point=\"process_data.py\",\n", " base_job_name=\"stable-diffusion-process-data\",\n", " role=role,\n", " image_uri = image_uri,\n", " source_dir=\"stable_scripts\",\n", " # configures the SageMaker training resource, you can increase as you need\n", " instance_count=1,\n", " instance_type=\"ml.c5n.18xlarge\",\n", " py_version=\"py38\",\n", " framework_version = '1.10',\n", " sagemaker_session=sess,\n", " hyperparameters = hyperparameters,\n", " debugger_hook_config=False,\n", " # enable warm pools for 60 minutes, useful for debugging\n", " keep_alive_period_in_seconds = 60 * 60,\n", " **kwargs\n", ")" ] }, { "cell_type": "code", "execution_count": 40, "id": "af3c7ab7-7282-4e4b-915a-c55da4f1b64e", "metadata": { "tags": [] }, "outputs": [], "source": [ "estimator.fit(inputs = data_channels, wait=False)" ] }, { "cell_type": "markdown", "id": "e91bc1b9-f0ec-4a85-8f92-3c15edc66bf4", "metadata": {}, "source": [ "### Step 3. Test the index and data loader locally" ] }, { "cell_type": "code", "execution_count": null, "id": "7b9ac323-48aa-438d-a193-7fada3006291", "metadata": {}, "outputs": [], "source": [ "!aws s3 cp " ] }, { "cell_type": "code", "execution_count": null, "id": "bf98d8bd-6406-4d30-80b2-4d126ffa8ed4", "metadata": {}, "outputs": [], "source": [ "# make sure this works locally. if not, you'll waste a ton of GPU time. \n", "def load_index(args):\n", " \n", " print ('loading the index')\n", " \n", " index_path = args.train_data_dir + '/' + args.index_name\n", "\n", " print ('pointing to', index_path)\n", " \n", " df = pd.read_json(index_path, lines=True)\n", " \n", " print ('read the dataframe, shape like', df.shape)\n", " \n", " dataset = Dataset.from_pandas(df)\n", " \n", " rt = DatasetDict({'train':dataset})\n", " \n", " print ('read the DatasetDict, columns like', rt.column_names)\n", " \n", " return rt " ] }, { "cell_type": "markdown", "id": "56b455b9-f0ce-4293-8204-326b10aa6388", "metadata": {}, "source": [ "### Step 4. 
Run the full job on SageMaker" ] }, { "cell_type": "code", "execution_count": 33, "id": "45423240-55d3-4240-b322-4f878bfa972b", "metadata": { "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting stable_scripts/finetune.py\n" ] } ], "source": [ "%%writefile stable_scripts/finetune.py\n", "import argparse\n", "import math\n", "import os\n", "import random\n", "from pathlib import Path\n", "from typing import Optional\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "import torch.utils.checkpoint\n", "\n", "from os import listdir\n", "import os\n", "from skimage import io\n", "\n", "import PIL\n", "from accelerate import Accelerator\n", "from accelerate.logging import get_logger\n", "from accelerate.utils import set_seed\n", "from datasets import load_dataset\n", "\n", "from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n", "\n", "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n", "\n", "from diffusers.optimization import get_scheduler\n", "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n", "from huggingface_hub import HfFolder, Repository, whoami\n", "from torchvision import transforms\n", "from torchvision.io import ImageReadMode, read_image\n", "from tqdm.auto import tqdm\n", "import requests\n", "from PIL import Image\n", "from mpi4py import MPI\n", "\n", "from datasets import load_dataset, Dataset, DatasetDict \n", "import json\n", "import glob\n", "import multiprocessing as mp\n", "from multiprocessing import Pool\n", "import pandas as pd\n", "\n", "logger = get_logger(__name__)\n", "\n", "def parse_args():\n", " parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n", " parser.add_argument(\n", " \"--pretrained_model_name_or_path\",\n", " type=str,\n", " default=None,\n", " required=True,\n", " help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n", " )\n", " parser.add_argument(\n", " \"--dataset_name\",\n", " type=str,\n", " default=None,\n", " help=(\n", " \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n", " \" dataset).\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--dataset_config_name\",\n", " type=str,\n", " default=None,\n", " help=\"The config of the Dataset, leave as None if there's only one config.\",\n", " )\n", " parser.add_argument(\"--train_data_dir\", type=str, default=os.environ['SM_CHANNEL_TRAINING'], help=\"A folder containing the training data.\")\n", " \n", " parser.add_argument(\n", " \"--validation_data_dir\", type=str, default=None, help=\"A folder containing the validation data.\"\n", " )\n", " parser.add_argument(\n", " \"--image_column\", type=str, default = os.environ[\"SM_HP_IMAGE_COLUMN\"], help=\"The column of the dataset containing an image.\"\n", " )\n", " parser.add_argument(\n", " \"--caption_column\",\n", " type=str,\n", " default=\"text\",\n", " help=\"The column of the dataset containing a caption or a list of captions.\",\n", " )\n", " parser.add_argument(\n", " \"--max_train_samples\",\n", " type=int,\n", " default=None,\n", " help=(\n", " \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n", " \"value if set.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--max_eval_samples\",\n", " type=int,\n", " default=None,\n", " 
help=(\n", " \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n", " \"value if set.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--train_val_split\",\n", " type=float,\n", " default=0.15,\n", " help=\"Percent to split off of train for validation\",\n", " )\n", " parser.add_argument(\n", " \"--output_dir\",\n", " type=str,\n", " default=\"sd-model-finetuned\",\n", " help=\"The output directory where the model predictions and checkpoints will be written.\",\n", " )\n", " parser.add_argument(\n", " \"--cache_dir\",\n", " type=str,\n", " default=None,\n", " help=\"The directory where the downloaded models and datasets will be stored.\",\n", " )\n", " parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n", " parser.add_argument(\n", " \"--resolution\",\n", " type=int,\n", " default=512,\n", " help=(\n", " \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n", " \" resolution\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--center_crop\",\n", " action=\"store_true\",\n", " help=\"Whether to center crop images before resizing to resolution (if not set, use random crop)\",\n", " )\n", " parser.add_argument(\n", " \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n", " )\n", " parser.add_argument(\n", " \"--eval_batch_size\", type=int, default=16, help=\"Batch size (per device) for the eval dataloader.\"\n", " )\n", " parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n", " parser.add_argument(\n", " \"--max_train_steps\",\n", " type=int,\n", " default=-1,\n", " help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n", " )\n", " parser.add_argument(\n", " \"--gradient_accumulation_steps\",\n", " type=int,\n", " default=1,\n", " help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n", " )\n", " parser.add_argument(\n", " \"--learning_rate\",\n", " type=float,\n", " default=1e-4,\n", " help=\"Initial learning rate (after the potential warmup period) to use.\",\n", " )\n", " parser.add_argument(\n", " \"--scale_lr\",\n", " action=\"store_true\",\n", " default=True,\n", " help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n", " )\n", " parser.add_argument(\n", " \"--lr_scheduler\",\n", " type=str,\n", " default=\"constant\",\n", " help=(\n", " 'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n", " ' \"constant\", \"constant_with_warmup\"]'\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n", " )\n", " parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n", " parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n", " parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n", " parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n", " parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n", " parser.add_argument(\n", " \"--use_auth_token\",\n", " action=\"store_true\",\n", " help=(\n", " \"Will use the token generated when running `huggingface-cli login` (necessary to use this script with\"\n", " \" private models).\"\n", " ),\n", " )\n", " parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n", " parser.add_argument(\n", " \"--hub_model_id\",\n", " type=str,\n", " default=None,\n", " help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n", " )\n", " parser.add_argument(\n", " \"--logging_dir\",\n", " type=str,\n", " default=\"logs\",\n", " help=(\n", " \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n", " \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--mixed_precision\",\n", " type=str,\n", " default=\"no\",\n", " choices=[\"no\", \"fp16\", \"bf16\"],\n", " help=(\n", " \"Whether to use mixed precision. Choose\"\n", " \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10.\"\n", " \"and an Nvidia Ampere GPU.\"\n", " ),\n", " )\n", " parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n", "\n", " parser.add_argument(\"--index_name\", type=str, default=os.environ['SM_HP_INDEX_NAME'], help=\"To point to the name of the index file on FSx for Lustre\")\n", " \n", " parser.add_argument(\"--n_rows\", type=int, default=os.environ['SM_HP_N_ROWS'], help=\"Defines the number of rows to read from the index file\")\n", " \n", " args = parser.parse_args()\n", " \n", " env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n", " \n", " if int(os.environ.get(\"OMPI_COMM_WORLD_LOCAL_RANK\",-1)) >= 0:\n", " env_local_rank = int(os.environ.get(\"OMPI_COMM_WORLD_LOCAL_RANK\"))\n", " os.environ['LOCAL_RANK'] = str(env_local_rank)\n", " args.local_rank = env_local_rank\n", " os.environ['RANK'] = os.environ.get(\"OMPI_COMM_WORLD_RANK\")\n", " os.environ['WORLD_SIZE'] = os.environ.get(\"OMPI_COMM_WORLD_SIZE\")\n", "\n", " # Sanity checks\n", " if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None:\n", " raise ValueError(\"Need either a dataset name or a training/validation folder.\")\n", "\n", " return args\n", "\n", "def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):\n", " if token is None:\n", " token = HfFolder.get_token()\n", " if organization is None:\n", " username = whoami(token)[\"name\"]\n", " return f\"{username}/{model_id}\"\n", " else:\n", " return f\"{organization}/{model_id}\"\n", "\n", "\n", "def freeze_params(params):\n", " for param in params:\n", " param.requires_grad = False\n", "\n", "dataset_name_mapping = {\n", " \"image_caption_dataset.py\": (\"image_path\", \"caption\"),\n", "}\n", " \n", "def main():\n", " \n", " args = parse_args()\n", " \n", " logging_dir = os.path.join(args.output_dir, args.logging_dir)\n", "\n", " accelerator = Accelerator(\n", " gradient_accumulation_steps=args.gradient_accumulation_steps,\n", " mixed_precision=args.mixed_precision\n", " )\n", "\n", " # If passed along, set the training seed now.\n", " if args.seed is not None:\n", " set_seed(args.seed)\n", "\n", " # Handle the repository creation\n", " if accelerator.is_main_process:\n", " \n", " if args.push_to_hub:\n", " if args.hub_model_id is None:\n", " repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)\n", " else:\n", " repo_name = args.hub_model_id\n", " repo = Repository(args.output_dir, clone_from=repo_name)\n", "\n", " with open(os.path.join(args.output_dir, \".gitignore\"), \"w+\") as gitignore:\n", " if \"step_*\" not in gitignore:\n", " gitignore.write(\"step_*\\n\")\n", " if \"epoch_*\" not in gitignore:\n", " gitignore.write(\"epoch_*\\n\")\n", " elif args.output_dir is not None:\n", " os.makedirs(args.output_dir, exist_ok=True)\n", "\n", " # Load models and create wrapper for stable diffusion\n", " tokenizer = CLIPTokenizer.from_pretrained(\n", " args.pretrained_model_name_or_path,\n", " subfolder=\"tokenizer\",\n", " use_auth_token=args.use_auth_token,\n", " )\n", " text_encoder = CLIPTextModel.from_pretrained(\n", " args.pretrained_model_name_or_path, subfolder=\"text_encoder\", use_auth_token=args.use_auth_token\n", " )\n", " vae = AutoencoderKL.from_pretrained(\n", " args.pretrained_model_name_or_path, subfolder=\"vae\", use_auth_token=args.use_auth_token\n", " )\n", " unet = UNet2DConditionModel.from_pretrained(\n", " args.pretrained_model_name_or_path, 
subfolder=\"unet\", use_auth_token=args.use_auth_token\n", " )\n", "\n", " # Freeze vae and text_encoder\n", " freeze_params(vae.parameters())\n", " freeze_params(text_encoder.parameters())\n", "\n", " if args.scale_lr:\n", " args.learning_rate = (\n", " args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n", " )\n", "\n", " # Initialize the optimizer\n", " optimizer = torch.optim.AdamW(\n", " unet.parameters(),\n", " lr=args.learning_rate,\n", " betas=(args.adam_beta1, args.adam_beta2),\n", " weight_decay=args.adam_weight_decay,\n", " eps=args.adam_epsilon,\n", " )\n", "\n", " # TODO (patil-suraj): load scheduler using args\n", " noise_scheduler = DDPMScheduler(\n", " beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000, tensor_format=\"pt\"\n", " )\n", "\n", " # Get the datasets: you can either provide your own training and evaluation files (see below)\n", " # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n", "\n", " # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n", " # download the dataset.\n", " if (args.dataset_name is not None) and ('.' not in args.dataset_name):\n", " # Downloading and loading a dataset from the hub.\n", " dataset = load_dataset(\n", " args.dataset_name,\n", " args.dataset_config_name,\n", " cache_dir=args.cache_dir,\n", " use_auth_token=True if args.use_auth_token else None,\n", " )\n", " elif (args.dataset_name is not None):\n", " dataset = load_dataset('parquet',data_files=args.dataset_name)\n", " \n", " else:\n", " data_files = {}\n", " if args.train_data_dir is not None:\n", " data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n", " if args.validation_data_dir is not None:\n", " data_files[\"validation\"] = os.path.join(args.validation_data_dir, \"**\") \n", "\n", " train_transforms = transforms.Compose(\n", " [\n", " transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n", " transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.5], [0.5]),\n", " ]\n", " )\n", " val_transforms = transforms.Compose(\n", " [\n", " transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n", " transforms.CenterCrop(args.resolution),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.5], [0.5]),\n", " ]\n", " )\n", "\n", " # this function expects the image path \n", " def preprocess_train(examples):\n", " images = [Image.open(image).convert(\"RGB\") for image in examples[image_column]]\n", " examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n", " examples[\"input_ids\"] = tokenize_captions(examples)\n", "\n", " return examples\n", "\n", " def preprocess_val(examples):\n", " images = [Image.open(image).convert(\"RGB\") for image in examples[image_column]]\n", " examples[\"pixel_values\"] = [val_transforms(image) for image in images]\n", " examples[\"input_ids\"] = tokenize_captions(examples, is_train=False)\n", " return examples\n", "\n", " with accelerator.main_process_first():\n", " \n", " print ('triggered main function')\n", " \n", " dataset = load_index_dataset(args.train_data_dir, args.n_rows)\n", "\n", " # If we don't have a validation split, split off a percentage of train as validation.\n", " args.train_val_split = None if \"validation\" in 
dataset.keys() else args.train_val_split\n", "\n", " if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:\n", " split = dataset[\"train\"].train_test_split(args.train_val_split)\n", " dataset[\"train\"] = split[\"train\"]\n", " dataset[\"validation\"] = split[\"test\"]\n", "\n", " # Preprocessing the datasets.\n", " \n", " # We need to tokenize inputs and targets.\n", " column_names = dataset[\"train\"].column_names\n", "\n", " # 6. Get the column names for input/target.\n", " dataset_columns = dataset_name_mapping.get(args.dataset_name, None)\n", "\n", " if args.image_column is None:\n", " image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n", " else:\n", " image_column = args.image_column\n", " if image_column not in column_names:\n", " raise ValueError(\n", " f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n", " )\n", " if args.caption_column is None:\n", " caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n", " else:\n", " caption_column = args.caption_column\n", " if caption_column not in column_names:\n", " raise ValueError(\n", " f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n", " )\n", "\n", " # Preprocessing the datasets.\n", " # We need to tokenize input captions and transform the images.\n", " def tokenize_captions(examples, is_train=True):\n", " captions = []\n", " for caption in examples[caption_column]:\n", " if isinstance(caption, str):\n", " captions.append(caption)\n", " elif isinstance(caption, (list, np.ndarray)):\n", " # take a random caption if there are multiple\n", " captions.append(random.choice(caption) if is_train else caption[0])\n", " else:\n", " raise ValueError(\n", " f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n", " )\n", " input_ids = tokenizer(captions, max_length=tokenizer.model_max_length, padding=True, truncation=True).input_ids\n", " return input_ids\n", " \n", " if args.max_train_samples is not None:\n", " dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n", " \n", " \n", " # Set the training transforms\n", " train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n", " \n", " \n", " if args.max_eval_samples is not None:\n", " dataset[\"validation\"] = dataset[\"validation\"].shuffle(seed=args.seed).select(range(args.max_eval_samples))\n", " # Set the validation transforms\n", " eval_dataset = dataset[\"validation\"].with_transform(preprocess_val)\n", "\n", " def collate_fn(examples):\n", " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", " input_ids = [example[\"input_ids\"] for example in examples]\n", " padded_tokens = tokenizer.pad(\n", " {\"input_ids\": input_ids},\n", " padding=\"max_length\",\n", " max_length=tokenizer.model_max_length,\n", " return_tensors=\"pt\",\n", " )\n", " return {\n", " \"pixel_values\": pixel_values,\n", " \"input_ids\": padded_tokens.input_ids,\n", " \"attention_mask\": padded_tokens.attention_mask,\n", " }\n", "\n", " \n", " train_dataloader = torch.utils.data.DataLoader(\n", " train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=4\n", " )\n", " eval_dataloader = torch.utils.data.DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, num_workers=4)\n", "\n", " # Scheduler and math around the number of training 
steps.\n", " overrode_max_train_steps = False\n", " num_update_steps_per_epoch = math.ceil(len(train_dataloader) / (args.gradient_accumulation_steps))\n", " if args.max_train_steps <= 0:\n", " args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n", " overrode_max_train_steps = True\n", "\n", " lr_scheduler = get_scheduler(\n", " args.lr_scheduler,\n", " optimizer=optimizer,\n", " num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n", " num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n", " )\n", "\n", " \n", " unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n", " unet, optimizer, train_dataloader, lr_scheduler\n", " )\n", "\n", " # Move vae and unet to device\n", " vae.to(accelerator.device)\n", " text_encoder.to(accelerator.device)\n", "\n", " # Keep vae and unet in eval model as we don't train these\n", " vae.eval()\n", " text_encoder.eval()\n", " # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n", " num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n", " if overrode_max_train_steps:\n", " args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n", " # Afterwards we recalculate our number of training epochs\n", " args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n", "\n", " # We need to initialize the trackers we use, and also store our configuration.\n", " # The trackers initializes automatically on the main process.\n", " if accelerator.is_main_process:\n", " accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n", "\n", " # Train!\n", " total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n", " if accelerator.is_main_process:\n", " logger.info(\"***** Running training *****\")\n", " logger.info(f\" Num examples = {len(train_dataset)}\")\n", " logger.info(f\" Num Epochs = {args.num_train_epochs}\")\n", " logger.info(f\" Instantaneous batch size per device = {args.train_batch_size}\")\n", " logger.info(f\" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n", " logger.info(f\" Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n", " logger.info(f\" Total optimization steps = {args.max_train_steps}\")\n", " # Only show the progress bar once on each machine.\n", " progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n", " progress_bar.set_description(\"Steps\")\n", " global_step = 0\n", "\n", " try:\n", " if accelerator.is_main_process:\n", " logger.info(\"using local safety checker\")\n", " safety_checker=StableDiffusionSafetyChecker.from_pretrained(args.pretrained_model_name_or_path,subfolder='safety_checker')\n", " feature_extractor=CLIPFeatureExtractor.from_pretrained(os.path.join(args.pretrained_model_name_or_path,'feature_extractor/preprocessor_config.json'))\n", " except Exception:\n", " if accelerator.is_main_process:\n", " logger.info(\"using hf download for safety checkers\")\n", " print(Exception)\n", " safety_checker=StableDiffusionSafetyChecker.from_pretrained(\"CompVis/stable-diffusion-safety-checker\")\n", " feature_extractor=CLIPFeatureExtractor.from_pretrained(\"openai/clip-vit-base-patch32\")\n", " \n", " accelerator.wait_for_everyone()\n", " \n", " for epoch in range(args.num_train_epochs):\n", " text_encoder.train()\n", " for step, batch in enumerate(train_dataloader):\n", " \n", " with accelerator.accumulate(unet):\n", " # Convert images to latent space\n", " latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample().detach()\n", " latents = latents * 0.18215\n", "\n", " # Sample noise that we'll add to the latents\n", " noise = torch.randn(latents.shape).to(latents.device)\n", " bsz = latents.shape[0]\n", " # Sample a random timestep for each image\n", " timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()\n", "\n", " # Add noise to the latents according to the noise magnitude at each timestep\n", " # (this is the forward diffusion process)\n", " noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n", "\n", " # Get the text embedding for conditioning\n", " encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n", "\n", " # Predict the noise residual and compute loss\n", " noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)[\"sample\"]\n", "\n", " loss = F.mse_loss(noise_pred, noise, reduction=\"none\")\n", " loss = loss.mean([1, 2, 3]).mean()\n", " accelerator.backward(loss)\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", "\n", "\n", " # Checks if the accelerator has performed an optimization step behind the scenes\n", " if accelerator.sync_gradients :\n", " for _ in range(accelerator.num_processes):\n", " progress_bar.update(1)\n", " global_step += 1 \n", "\n", " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n", " progress_bar.set_postfix(**logs)\n", " accelerator.log(logs, step=global_step)\n", "\n", " if global_step >= args.max_train_steps:\n", " break\n", "\n", " accelerator.wait_for_everyone()\n", "\n", " # Create the pipeline using the trained modules and save it.\n", " if accelerator.is_main_process:\n", " pipeline = StableDiffusionPipeline(\n", " text_encoder=accelerator.unwrap_model(text_encoder),\n", " vae=vae,\n", " unet=unet.module if accelerator.num_processes >1 else unet,\n", " tokenizer=tokenizer,\n", " scheduler=PNDMScheduler(\n", " beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", 
skip_prk_steps=True\n", "            ),\n", "            safety_checker=safety_checker,\n", "            feature_extractor=feature_extractor,\n", "        )\n", "        pipeline.save_pretrained(args.output_dir)\n", "\n", "        if args.push_to_hub:\n", "            repo.push_to_hub(\n", "                args, pipeline, repo, commit_message=\"End of training\", blocking=False, auto_lfs_prune=True\n", "            )\n", "\n", "    accelerator.end_training()\n", "\n", "def _mp_fn(index):\n", "    main()\n", "\n", "def read_index(index_path, n_rows=10000000):\n", "\n", "    data = []\n", "\n", "    count = 0\n", "\n", "    with open(index_path) as f:\n", "\n", "        for row in f.readlines():\n", "\n", "            try:\n", "                j = json.loads(row.strip())\n", "\n", "                if '.jpg' in j['image']:\n", "\n", "                    # only keep valid image pointers\n", "\n", "                    data.append(j)\n", "                    count += 1\n", "                    if count >= n_rows:\n", "                        return data\n", "            except Exception:\n", "                # skip malformed rows rather than failing the whole read\n", "                continue\n", "\n", "    return data\n", "\n", "def load_index_dataset(train_data_dir, n_rows):\n", "\n", "    index_path = train_data_dir + '/data_index.jsonl'\n", "\n", "    print ('reading the index from: {}'.format(index_path))\n", "\n", "    data = read_index(index_path, n_rows)\n", "\n", "    print ('read {} objects from index path'.format(len(data)))\n", "\n", "    df = pd.DataFrame.from_records(data)\n", "\n", "    print ('pandas df has shape of {}'.format(df.shape))\n", "\n", "    dataset = Dataset.from_pandas(df)\n", "\n", "    rt = DatasetDict({'train': dataset})\n", "\n", "    return rt\n", "\n", "if __name__ == \"__main__\":\n", "\n", "    main()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0c9ceb88-c1b1-4f57-ada4-3fd9a37741bf", "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker.huggingface import HuggingFace\n", "\n", "sess = sagemaker.Session()\n", "\n", "role = sagemaker.get_execution_role()\n", "\n", "bucket = sess.default_bucket()\n", "\n", "version = 'v1'\n", "\n", "image_uri = '220691188711.dkr.ecr.us-east-1.amazonaws.com/stable-diffusion:{}'.format(version)\n", "\n", "# required in this version of the train script\n", "data_channels['sd_base_model'] = 's3://dist-train/stable-diffusion/conceptual_captions/sd-base-model/'\n", "\n", "hyperparameters = {'pretrained_model_name_or_path': '/opt/ml/input/data/sd_base_model',\n", "                   'train_data_dir': '/opt/ml/input/data/training/laion-fsx',\n", "                   'index_name': 'data_index.jsonl',\n", "                   'caption_column': 'caption',\n", "                   'image_column': 'image',\n", "                   'resolution': 256,\n", "                   'mixed_precision': 'fp16',\n", "                   # this is per device\n", "                   'train_batch_size': 22,\n", "                   'learning_rate': '1e-10',\n", "                   # 'max_train_steps': 1000000,\n", "                   'num_train_epochs': 1,\n", "                   'output_dir': '/opt/ml/model/sd-output-final',\n", "                   'n_rows': 50000000}\n", "\n", "est = HuggingFace(entry_point='finetune.py',\n", "                  source_dir='stable_scripts',\n", "                  image_uri=image_uri,\n", "                  sagemaker_session=sess,\n", "                  role=role,\n", "                  output_path=\"s3://{}/output/model/\".format(bucket),\n", "                  instance_type='ml.p4d.24xlarge',\n", "                  keep_alive_period_in_seconds=60 * 60,\n", "                  py_version='py38',\n", "                  base_job_name='fsx-stable-diffusion',\n", "                  instance_count=24,\n", "                  enable_network_isolation=True,\n", "                  encrypt_inter_container_traffic=True,\n", "                  # all /opt/ml paths point to locations inside the SageMaker training container\n", "                  hyperparameters=hyperparameters,\n", "                  distribution={\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}},\n", "                  max_retry_attempts=30,\n", "                  max_run=4 * 60 * 60,\n", "                  debugger_hook_config=False,\n", "                  disable_profiler=True,\n", "                  **kwargs)\n", "\n", "est.fit(inputs=data_channels, wait=False)" ] } ], 
"metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }