# Train Stable Diffusion with SageMaker distributed training
For this notebook to work properly, make sure that FSx for Lustre has actually imported all of your data from S3. To do this, open the FSx for Lustre console, select the data repository association that points to your S3 path, then choose **Actions** → **Create import task**. This starts a task that imports all of the data under that S3 path into the Lustre file system.

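Specifically, the processing script below expects the training data directory to contain multiple folders named `part-*` (for example `part-00000`), each holding `.tar` shards and/or `.jpg` images with matching `.txt` caption files.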

### Step 1. Point to FSx for Lustre

In [2]:
from sagemaker.inputs import FileSystemInput

# FSx for Lustre file system that holds the training data.
file_system_id = "fs-0a83907c9c9c7b8f0"

# Security group and subnet the FSx file system was created with. They are
# forwarded to the SageMaker estimators below so the jobs run in the same
# network as the file system.
fsx_security_group_id = "sg-ac4f1cb5"
fsx_subnet = "subnet-be054be1"

# Absolute, normalized directory on the file system to mount. For FSx for
# Lustre it starts with the mount name, which you choose when creating the
# file system or which is generated for you (e.g. "3x5lhbmv"); it is shown on
# the FSx page in the console.
base_path = "/yflftbev"

file_system_type = "FSxLustre"

# Mounted read-write: the processing job extracts archives and writes its
# index file inside this directory.
train = FileSystemInput(
    file_system_id=file_system_id,
    file_system_type=file_system_type,
    directory_path=base_path,
    file_system_access_mode="rw",
)

# Channel name -> input; scripts see this one under /opt/ml/input/data/train.
data_channels = dict(train=train)

In [3]:
# Extra estimator arguments: run the jobs with the same security group and
# subnet that the FSx file system was created with, so they can reach it.
kwargs = {
    "security_group_ids": [fsx_security_group_id],
    "subnets": [fsx_subnet],
}

### Step 2. Process data and build a JSON Lines index
In my implementation, I built my own data loader that reads a custom JSON Lines index file. This saved a lot of time when loading the data: rather than needing to `ls` all of the files, I had the full list predefined. That might not sound like a big deal, but once you're working with more than a few million image/text pairs, it adds up!

Details on my full case study are [available here](https://medium.com/@emilywebber/how-i-trained-10tb-for-stable-diffusion-on-sagemaker-39dcea49ce32).

In [4]:
# -p: don't fail when the directory already exists (e.g. when re-running the notebook)
!mkdir -p stable_scripts

In [34]:
%%writefile stable_scripts/process_data.py

import argparse
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from os import listdir
import os
from skimage import io

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from torchvision.io import ImageReadMode, read_image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import requests
from PIL import Image
# from mpi4py import MPI

from datasets import load_dataset 
from datasets import Dataset
from datasets import DatasetDict

import glob
import multiprocessing as mp
from multiprocessing import Pool
import pandas as pd
import json

def parse_args():
    '''
    Parse the options of the data-processing job.

    Each option falls back to the SageMaker environment variable that carries
    the matching hyperparameter (``SM_HP_*``) or the model directory
    (``SM_MODEL_DIR``). The fallbacks use ``os.environ.get`` so the script can
    also be run outside SageMaker by passing the flags explicitly; indexing
    ``os.environ`` directly would raise ``KeyError`` while the parser is being
    built, before the command line is even read.

    Returns:
        argparse.Namespace with ``train_data_dir``, ``function``, ``model_dir``
        and ``index_name``.
    '''
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    parser.add_argument("--train_data_dir", type=str, default=os.environ.get('SM_HP_TRAIN_DATA_DIR'), help="A folder containing the training data.")

    parser.add_argument("--function", type=str, default=os.environ.get('SM_HP_FUNCTION'), help="A generic argument to determine function of this script. Could be unzip and/or build pointer")

    parser.add_argument("--model_dir", type=str, default=os.environ.get('SM_MODEL_DIR'), help="SM training path for model, will copy to S3 after job completes")

    parser.add_argument("--index_name", type=str, default=os.environ.get('SM_HP_INDEX_NAME', 'data_index.jsonl'), help="To point to the name of the index file")

    args = parser.parse_args()

    # Both values are required downstream (globbing the parts and choosing the steps);
    # fail here with a usage message instead of an obscure error later.
    if args.train_data_dir is None:
        parser.error("--train_data_dir is required (or set SM_HP_TRAIN_DATA_DIR)")
    if args.function is None:
        parser.error("--function is required (or set SM_HP_FUNCTION)")

    return args

def read_caption(path_to_image):
    '''
    Return the lines of the caption file that sits next to an image.

    The caption is expected at the image's path with the extension swapped for
    ``.txt`` (``/a/b/0001.jpg`` -> ``/a/b/0001.txt``). Only the extension is
    replaced; a plain ``str.replace('jpg', 'txt')`` would also rewrite any
    "jpg" appearing in directory or file names.

    Args:
        path_to_image: full path to the image file.

    Returns:
        list[str]: the caption file's lines, line terminators included.

    Raises:
        OSError: if the caption file cannot be opened.
        UnicodeDecodeError: if the caption file is not valid UTF-8.
    '''
    path_to_text = os.path.splitext(path_to_image)[0] + '.txt'

    with open(path_to_text, encoding='utf-8') as f:
        return f.readlines()
 
def unzip_part(full_part_path):
    '''
    Extract, in place, every ``.tar`` archive found directly inside one part folder.

    Extraction uses tar's ``--skip-old-files``, so files that already exist are
    left untouched and the function can be re-run after a partial extraction.
    A failing archive is reported and skipped (best effort), not raised.

    Args:
        full_part_path: path to one ``part-*`` folder.
    '''
    import subprocess

    print('Working on part num: {}'.format(full_part_path))

    # glob.escape: the folder name itself must not be interpreted as a pattern
    part_pattern = glob.escape(full_part_path)

    img_list = glob.glob(os.path.join(part_pattern, '*.jpg'))

    print('This part now has {} images!'.format(len(img_list)))

    # look for all tar balls in this part path
    tar_balls = glob.glob(os.path.join(part_pattern, '*.tar'))

    # this would be the place to try and add multiprocessing
    for tball in tar_balls:
        # Arguments go to tar directly (no shell), so paths with spaces or shell
        # metacharacters are passed through intact.
        result = subprocess.run(
            ['tar', '-xf', tball, '--skip-old-files', '--directory', full_part_path]
        )
        if result.returncode != 0:
            print('tar exited with code {} for {}'.format(result.returncode, tball))
 
def write_index(full_part_path):
    '''
    Append one JSON line per image/caption pair of a part folder to the shared index.

    The index is ``data_index.jsonl`` in the parent directory of the part
    folder. Each line is ``{"image": <image path>, "caption": <first caption line>}``.
    Images whose caption cannot be used (missing or unreadable ``.txt`` file,
    empty file, undecodable text) are skipped.

    This runs in several pool workers that append to the same file, so the
    part's lines are collected first and appended with a single ``write``
    instead of one small buffered write per record, which could interleave
    partial lines from different workers. (A single large append makes
    interleaving far less likely but is not guaranteed atomic on every
    file system.)

    Args:
        full_part_path: path to one ``part-*`` folder.
    '''
    print('Working on part num: {}'.format(full_part_path))

    img_list = glob.glob(os.path.join(glob.escape(full_part_path), '*.jpg'))

    print('This part now has {} images!'.format(len(img_list)))

    # Parent directory of the part folder. Splitting on the substring "part"
    # would cut the path at the first "part" anywhere in it (e.g. "/data/partitions/...").
    index_path = os.path.join(os.path.dirname(os.path.normpath(full_part_path)), 'data_index.jsonl')

    print('Writing index to {}'.format(index_path))

    lines = []
    for path_to_image in img_list:
        try:
            caption = read_caption(path_to_image)[0]
        except (OSError, IndexError, ValueError):
            # missing/unreadable caption file, empty caption file, or bad encoding
            continue
        lines.append(json.dumps({"image": path_to_image, "caption": caption}) + '\n')

    with open(index_path, 'a') as fp:
        fp.write(''.join(lines))
 
if __name__ == "__main__":
    import shutil

    args = parse_args()

    print('Train data dir is here: {}'.format(args.train_data_dir))

    part_list = glob.glob("{}/part-*".format(args.train_data_dir))

    print('Found {} parts to work on, starting multiprocessing pool'.format(len(part_list)))

    # write_index appends to data_index.jsonl in the parent folder of the parts,
    # i.e. directly in train_data_dir.
    index_path = os.path.join(args.train_data_dir, 'data_index.jsonl')

    cpus = mp.cpu_count()

    with Pool(cpus) as p:

        if 'unzip' in args.function:
            p.map(unzip_part, part_list)

        if 'index' in args.function:
            # The index is opened in append mode by every worker; start from an
            # empty file so re-running the job does not duplicate every entry.
            if os.path.exists(index_path):
                os.remove(index_path)
            p.map(write_index, part_list)

    # The model dir is copied to S3 after the job completes, which makes the
    # index downloadable for local testing.
    if os.path.exists(index_path):
        shutil.copy(index_path, args.model_dir)
    else:
        print('No index found at {}, nothing to copy'.format(index_path))

Overwriting stable_scripts/process_data.py


#### Now let's run that as a SageMaker training job

In [35]:
version = 'v1'

# Prebuilt training image made and hosted in ECR by the notebook's author.
image_uri = f'220691188711.dkr.ecr.us-east-1.amazonaws.com/stable-diffusion:{version}'

In [39]:
import sagemaker
from sagemaker.pytorch import PyTorch

sess = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sess.default_bucket()

# Read by process_data.py's parse_args through the SM_HP_* environment
# variables. "function" selects the steps: extract the .tar shards, then
# build the JSON Lines index.
hyperparameters = dict(
    train_data_dir='/opt/ml/input/data/train/fsx-data',
    function='unzip,index',
    index_name='data_index.jsonl',
)

# process_data.py spreads the work over a process pool with one worker per CPU.
estimator = PyTorch(
    entry_point="process_data.py",
    source_dir="stable_scripts",
    base_job_name="stable-diffusion-process-data",
    image_uri=image_uri,
    role=role,
    sagemaker_session=sess,
    # configures the SageMaker training resource, you can increase as you need
    instance_type="ml.c5n.18xlarge",
    instance_count=1,
    framework_version='1.10',
    py_version="py38",
    hyperparameters=hyperparameters,
    debugger_hook_config=False,
    # enable warm pools for 60 minutes, useful for debugging
    keep_alive_period_in_seconds=60 * 60,
    **kwargs
)

In [40]:
# Launch the processing job without blocking the notebook.
estimator.fit(inputs=data_channels, wait=False)

### Step 3. Test the index and data loader locally

In [None]:
# Fetch the processing job's output: process_data.py copies data_index.jsonl into the
# model dir, which SageMaker uploads to S3 as the job's model artifact (presumably
# model.tar.gz — confirm in the job's output). The job was started with wait=False, so
# make sure it has finished first, e.g. with estimator.latest_training_job.wait().
!aws s3 cp {estimator.model_data} .
!tar -xzf model.tar.gz

In [None]:
# make sure this works locally. if not, you'll waste a ton of GPU time.
def load_index(args):
    """Load the JSON Lines index into a DatasetDict with a single "train" split.

    Args:
        args: any object with ``train_data_dir`` (folder holding the index) and
            ``index_name`` (index file name) attributes, e.g. an argparse
            Namespace or ``types.SimpleNamespace``.

    Returns:
        datasets.DatasetDict: ``{"train": Dataset}`` with one column per JSON
        key (``image`` and ``caption`` for indexes written by process_data.py).

    Example:
        >>> from types import SimpleNamespace
        >>> load_index(SimpleNamespace(train_data_dir='.', index_name='data_index.jsonl'))  # doctest: +SKIP
    """
    # These names are imported inside the scripts under stable_scripts/, not in
    # this notebook's kernel, so import them here.
    import os

    import pandas as pd
    from datasets import Dataset, DatasetDict

    print('loading the index')

    index_path = os.path.join(args.train_data_dir, args.index_name)

    print('pointing to', index_path)

    df = pd.read_json(index_path, lines=True)

    print('read the dataframe, shape like', df.shape)

    dataset = Dataset.from_pandas(df)

    rt = DatasetDict({'train': dataset})

    print('read the DatasetDict, columns like', rt.column_names)

    return rt

### Step 4. Run the full job on SageMaker

In [33]:
%%writefile stable_scripts/finetune.py
import argparse
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from os import listdir
import os
from skimage import io

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset

from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from torchvision.io import ImageReadMode, read_image
from tqdm.auto import tqdm
import requests
from PIL import Image
from mpi4py import MPI

from datasets import load_dataset, Dataset, DatasetDict 
import json
import glob
import multiprocessing as mp
from multiprocessing import Pool
import pandas as pd

# Module-level accelerate logger; main() uses it for the training summary and
# safety-checker messages, guarded by accelerator.is_main_process.
logger = get_logger(__name__)

def parse_args():
    """Parse the fine-tuning options.

    Several defaults come from SageMaker environment variables. They are read
    with ``os.environ.get`` so a missing variable yields a usable default
    instead of a ``KeyError`` while the parser is being built:

    * ``--train_data_dir``: ``SM_CHANNEL_TRAINING``, else ``SM_CHANNEL_TRAIN``
      (the channel name the notebook uses), else ``None``.
    * ``--image_column``: ``SM_HP_IMAGE_COLUMN``, else ``None`` (main() then
      picks a default column).
    * ``--index_name``: ``SM_HP_INDEX_NAME``, else ``"data_index.jsonl"``.
    * ``--n_rows``: ``SM_HP_N_ROWS``, else 10000000.

    Rank information exported by a launcher (``LOCAL_RANK``, or Open MPI's
    ``OMPI_COMM_WORLD_*``) overrides ``--local_rank``.

    Returns:
        argparse.Namespace with all options.

    Raises:
        ValueError: if neither a dataset name nor a training/validation folder is available.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset)."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAINING", os.environ.get("SM_CHANNEL_TRAIN")),
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--validation_data_dir", type=str, default=None, help="A folder containing the validation data."
    )
    parser.add_argument(
        "--image_column",
        type=str,
        default=os.environ.get("SM_HP_IMAGE_COLUMN"),
        help="The column of the dataset containing an image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Percent to split off of train for validation",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution (if not set, use random crop)",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=-1,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    # --scale_lr is store_true with default=True, so it cannot switch scaling off;
    # this companion flag can.
    parser.add_argument(
        "--no_scale_lr",
        dest="scale_lr",
        action="store_false",
        help="Disable learning-rate scaling (see --scale_lr).",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help=(
            "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
            " private models)."
        ),
    )
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    parser.add_argument(
        "--index_name",
        type=str,
        default=os.environ.get("SM_HP_INDEX_NAME", "data_index.jsonl"),
        help="To point to the name of the index file on FSx for Lustre",
    )

    # A string default (from the environment) is converted by argparse via type=int.
    parser.add_argument(
        "--n_rows",
        type=int,
        default=os.environ.get("SM_HP_N_ROWS", 10000000),
        help="Defines the number of rows to read from the index file",
    )

    args = parser.parse_args()

    # A LOCAL_RANK exported by a launcher takes precedence over the --local_rank default.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Under Open MPI the ranks arrive as OMPI_COMM_WORLD_*; mirror them into
    # LOCAL_RANK / RANK / WORLD_SIZE (presumably so accelerate / torch.distributed
    # pick them up). Variables that are absent are left alone instead of
    # failing with a TypeError.
    ompi_local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", -1))
    if ompi_local_rank >= 0:
        os.environ['LOCAL_RANK'] = str(ompi_local_rank)
        args.local_rank = ompi_local_rank
        for ompi_name, target_name in (("OMPI_COMM_WORLD_RANK", "RANK"), ("OMPI_COMM_WORLD_SIZE", "WORLD_SIZE")):
            if ompi_name in os.environ:
                os.environ[target_name] = os.environ[ompi_name]

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None:
        raise ValueError("Need either a dataset name or a training/validation folder.")

    return args

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return the Hub repository id ``"<owner>/<model_id>"``.

    The owner is *organization* when given; otherwise it is the user name that
    ``whoami`` reports for *token*. A missing *token* is first looked up with
    ``HfFolder.get_token()``.
    """
    token = HfFolder.get_token() if token is None else token
    owner = whoami(token)["name"] if organization is None else organization
    return f"{owner}/{model_id}"


def freeze_params(params):
    """Turn off gradient tracking for every parameter yielded by *params* (any iterable)."""
    for tensor in params:
        setattr(tensor, "requires_grad", False)

# Dataset name -> (image column, caption column). main() falls back to these
# names when --image_column / --caption_column are not given.
dataset_name_mapping = {"image_caption_dataset.py": ("image_path", "caption")}
 
def main():
    """Fine-tune the Stable Diffusion UNet and save the resulting pipeline.

    Steps:
      1. Parse arguments and create the ``Accelerator`` (mixed precision,
         gradient accumulation, multi-process).
      2. Load tokenizer, text encoder, VAE and UNet from
         ``args.pretrained_model_name_or_path``; freeze the VAE and text
         encoder so only the UNet is optimized.
      3. Load the data: from ``args.dataset_name`` when given (a Hub dataset,
         or parquet file(s) when the name contains a "."), otherwise from the
         JSON Lines index in ``args.train_data_dir`` via ``load_index_dataset``.
      4. Train the UNet to predict the noise added to the VAE latents (MSE loss).
      5. On the main process, save a ``StableDiffusionPipeline`` to
         ``args.output_dir`` and optionally push it to the Hub.
    """
    args = parse_args()

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            # "w+" truncates the file, so both patterns are always written.
            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                gitignore.write("step_*\n")
                gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        use_auth_token=args.use_auth_token,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
    )

    # Freeze vae and text_encoder; only the UNet is optimized.
    freeze_params(vae.parameters())
    freeze_params(text_encoder.parameters())

    if args.scale_lr:
        # Scale by the effective global batch size.
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer (UNet parameters only)
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    val_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # These functions expect the image column to hold image file paths.
    # image_column and tokenize_captions are defined further down; they are
    # only looked up when the transforms run.
    def preprocess_train(examples):
        images = [Image.open(image).convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def preprocess_val(examples):
        images = [Image.open(image).convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [val_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples, is_train=False)
        return examples

    # The main process runs this block first, the others afterwards.
    with accelerator.main_process_first():

        print('triggered main function')

        if args.dataset_name is None:
            # Default: the JSON Lines index built by process_data.py.
            dataset = load_index_dataset(args.train_data_dir, args.n_rows)
        elif '.' not in args.dataset_name:
            # Downloading and loading a dataset from the hub.
            dataset = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                cache_dir=args.cache_dir,
                use_auth_token=True if args.use_auth_token else None,
            )
        else:
            # A name containing "." is treated as parquet file(s).
            dataset = load_dataset('parquet', data_files=args.dataset_name)

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split

    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        # Every process computes this split on its own, so it must use the same
        # seed everywhere; otherwise ranks would train on different subsets and
        # validation examples of one rank would be training examples of another.
        split_seed = args.seed if args.seed is not None else 0
        split = dataset["train"].train_test_split(args.train_val_split, seed=split_seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)

    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        input_ids = tokenizer(captions, max_length=tokenizer.model_max_length, padding=True, truncation=True).input_ids
        return input_ids

    if args.max_train_samples is not None:
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))

    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)

    # A validation split exists only if the dataset had one or train_val_split > 0.
    eval_dataset = None
    if "validation" in dataset:
        if args.max_eval_samples is not None:
            dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
        # Set the validation transforms
        eval_dataset = dataset["validation"].with_transform(preprocess_val)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        input_ids = [example["input_ids"] for example in examples]
        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        return {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
            "attention_mask": padded_tokens.attention_mask,
        }

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=4
    )
    # NOTE: built for completeness; no evaluation loop below consumes it.
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, num_workers=4
        )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / (args.gradient_accumulation_steps))
    if args.max_train_steps <= 0:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Move vae and text_encoder to device
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    # Keep vae and text_encoder in eval mode as we don't train these
    vae.eval()
    text_encoder.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    # After prepare() this is the number of update steps of ONE process.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    if accelerator.is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {args.num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    try:
        if accelerator.is_main_process:
            logger.info("using local safety checker")
            safety_checker = StableDiffusionSafetyChecker.from_pretrained(args.pretrained_model_name_or_path, subfolder='safety_checker')
            feature_extractor = CLIPFeatureExtractor.from_pretrained(os.path.join(args.pretrained_model_name_or_path, 'feature_extractor/preprocessor_config.json'))
    except Exception as err:
        # Only the main process runs the loads above, so only it can get here.
        if accelerator.is_main_process:
            logger.info("using hf download for safety checkers")
            print(err)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

    accelerator.wait_for_everyone()

    for epoch in range(args.num_train_epochs):
        # The UNet is the model being trained; the VAE and text encoder stay in eval mode.
        unet.train()
        for batch in train_dataloader:

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents (directly on the latents' device)
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual and compute loss
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]

                loss = F.mse_loss(noise_pred, noise, reduction="none")
                loss = loss.mean([1, 2, 3]).mean()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes.
            # max_train_steps counts steps of one process, so advance by exactly one;
            # advancing by num_processes ended training after 1/num_processes of the data.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline(
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=vae,
            unet=accelerator.unwrap_model(unet),
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            # Repository.push_to_hub commits and pushes the local clone in output_dir.
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()
 
def _mp_fn(index):
    """Per-process entry point for launchers that pass a process index; the index is not used."""
    return main()

def read_index(index_path, n_rows=10000000):
    '''
    Read up to ``n_rows`` valid records from a JSON Lines index file.

    A record is kept when its line parses as JSON with an ``image`` value that
    contains ``.jpg``. Blank lines, malformed JSON and records without a
    usable ``image`` value are skipped.

    Args:
        index_path: path to the ``.jsonl`` index.
        n_rows: maximum number of records to return; ``None`` means no limit
            and a value <= 0 returns an empty list.

    Returns:
        list[dict]: the kept records, in file order.
    '''
    if n_rows is not None and n_rows <= 0:
        return []

    data = []

    # Iterate the file lazily instead of readlines(): the index can hold many
    # millions of rows and only n_rows of them may be needed.
    with open(index_path, encoding='utf-8') as f:
        for row in f:
            try:
                j = json.loads(row)
                # only keep valid image pointers
                keep = '.jpg' in j['image']
            except (ValueError, KeyError, TypeError):
                # malformed/blank line, missing key, or a non-object record
                continue

            if keep:
                data.append(j)
                if n_rows is not None and len(data) >= n_rows:
                    break

    return data

def load_index_dataset(train_data_dir, n_rows, index_name='data_index.jsonl'):
    '''
    Build a ``DatasetDict`` with a single ``train`` split from the JSON Lines index.

    Args:
        train_data_dir: folder that contains the index file.
        n_rows: maximum number of records to load (forwarded to ``read_index``).
        index_name: name of the index file inside ``train_data_dir``.

    Returns:
        datasets.DatasetDict: ``{'train': Dataset}`` with one column per record key.

    Raises:
        ValueError: if no valid records were read. An empty dataset has no
            columns and would only fail later, in the column-name checks.
    '''
    index_path = os.path.join(train_data_dir, index_name)

    print('reading the index from: {}'.format(index_path))

    data = read_index(index_path, n_rows)

    print('read {} objects from index path'.format(len(data)))

    if not data:
        raise ValueError('No valid records found in index {}'.format(index_path))

    df = pd.DataFrame.from_records(data)

    print('pandas df has shape of {}'.format(df.shape))

    return DatasetDict({'train': Dataset.from_pandas(df)})

if __name__ == "__main__":
    # Script entry point: parse arguments, train, and save the pipeline.
    main()
 

Overwriting stable_scripts/finetune.py


In [None]:
import sagemaker
from sagemaker.huggingface import HuggingFace

sess = sagemaker.Session()

role = sagemaker.get_execution_role()

bucket = sess.default_bucket()

version = 'v1'

image_uri = '220691188711.dkr.ecr.us-east-1.amazonaws.com/stable-diffusion:{}'.format(version)

# required in this version of the train script: the base model weights arrive as their
# own channel, read below from /opt/ml/input/data/sd_base_model
data_channels['sd_base_model'] = 's3://dist-train/stable-diffusion/conceptual_captions/sd-base-model/'

hyperparameters = {'pretrained_model_name_or_path': '/opt/ml/input/data/sd_base_model',
                   # The FSx channel is named "train" (see data_channels), so it is mounted
                   # under /opt/ml/input/data/train/. Point at the sub-folder that holds
                   # data_index.jsonl (the folder process_data.py indexed).
                   'train_data_dir': '/opt/ml/input/data/train/laion-fsx',
                   'index_name': 'data_index.jsonl',
                   'caption_column': 'caption',
                   'image_column': 'image',
                   'resolution': 256,
                   'mixed_precision': 'fp16',
                   # this is per device
                   'train_batch_size': 22,
                   # finetune.py scales this by train_batch_size * gradient_accumulation_steps
                   # * number of processes (--scale_lr is on by default)
                   'learning_rate': '1e-10',
                   # 'max_train_steps':1000000,
                   'num_train_epochs': 1,
                   'output_dir': '/opt/ml/model/sd-output-final',
                   'n_rows': 50000000}

est = HuggingFace(entry_point='finetune.py',
                  source_dir='stable_scripts',
                  image_uri=image_uri,
                  sagemaker_session=sess,
                  role=role,
                  output_path="s3://{}/output/model/".format(bucket),
                  instance_type='ml.p4dn.24xlarge',
                  keep_alive_period_in_seconds=60 * 60,
                  py_version='py38',
                  base_job_name='fsx-stable-diffusion',
                  instance_count=24,
                  enable_network_isolation=True,
                  encrypt_inter_container_traffic=True,
                  # all opt/ml paths point to SageMaker training
                  hyperparameters=hyperparameters,
                  distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
                  max_retry_attempts=30,
                  max_run=4 * 60 * 60,
                  debugger_hook_config=False,
                  disable_profiler=True,
                  **kwargs)

est.fit(inputs=data_channels, wait=False)