import math
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
import torch.utils.data.distributed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from torch.distributed.distributed_c10d import ReduceOp
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from utils import parse_args, is_main_process, main_process_first, is_local_main_process, wait_for_everyone


def freeze_params(params):
    for param in params:
        param.requires_grad = False


dataset_name_mapping = {
    "image_caption_dataset.py": ("image_path", "caption"),
}


def main():
    args = parse_args()
    args.world_size = int(os.environ["WORLD_SIZE"])
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.global_rank = int(os.environ["RANK"])
    print(f"local rank:{args.local_rank} global rank:{args.global_rank} world size:{args.world_size}")

    torch.cuda.set_device(args.local_rank)
    # Initialize the distributed process group (e.g. with the NCCL backend).
    dist.init_process_group(backend=args.backend)

    # We will run the training with reduced precision to get better memory utilization.
    if args.datatype == "fp16":
        train_dtype = torch.float16
    elif args.datatype == "bf16":
        train_dtype = torch.bfloat16
    else:
        # Fall back to full precision for any other value.
        train_dtype = torch.float32

    # If passed along, set the training seed now.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Load models and create wrapper for stable diffusion.
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        use_auth_token=args.use_auth_token,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
    )

    # Freeze vae and text_encoder as we will be fine-tuning only the unet model.
    freeze_params(vae.parameters())
    freeze_params(text_encoder.parameters())

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * args.world_size
        )
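        # For example, with a base learning rate of 1e-04, gradient_accumulation_steps=2,
        # train_batch_size=8 and world_size=4, the scaled learning rate is
        # 1e-04 * 2 * 8 * 4 = 6.4e-03, i.e. linear scaling with the effective global batch size.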

    # Initialize the optimizer.
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if (args.dataset_name is not None) and ("." not in args.dataset_name):
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            use_auth_token=True if args.use_auth_token else None,
        )
    elif args.dataset_name is not None:
        # A dataset name containing a dot is treated as a path to local parquet files.
        dataset = load_dataset("parquet", data_files=args.dataset_name)
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        if args.validation_data_dir is not None:
            data_files["validation"] = os.path.join(args.validation_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_process#imagefolder.

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # Take a random caption if there are multiple.
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        input_ids = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding=True, truncation=True
        ).input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    val_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def preprocess_val(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [val_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples, is_train=False)
        return examples

    with main_process_first(args.global_rank):
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms.
        train_dataset = dataset["train"].with_transform(preprocess_train)
        if args.max_eval_samples is not None:
            dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
        # Set the validation transforms.
        eval_dataset = dataset["validation"].with_transform(preprocess_val)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        input_ids = [example["input_ids"] for example in examples]
        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        return {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
            "attention_mask": padded_tokens.attention_mask,
        }
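    # Each batch produced by collate_fn is a dict of tensors; with the standard CLIP
    # tokenizer (model_max_length of 77) the shapes are roughly:
    #   pixel_values:   [train_batch_size, 3, resolution, resolution]
    #   input_ids:      [train_batch_size, 77]
    #   attention_mask: [train_batch_size, 77]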

    sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        shuffle=True,
        seed=args.seed,
        rank=args.global_rank,
        num_replicas=args.world_size,
        drop_last=True,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, sampler=sampler, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=0
    )
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, num_workers=0
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps <= 0:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
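    # As an illustration: 8,000 training samples sharded across world_size=4 with
    # train_batch_size=8 gives len(train_dataloader) = 250 batches per rank per epoch;
    # with gradient_accumulation_steps=2 that is ceil(250 / 2) = 125 update steps per epoch.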

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Move the models to this rank's GPU (selected via torch.cuda.set_device above).
    device = torch.device("cuda")
    unet.to(device)
    # Wrap the model with DDP.
    unet = DDP(unet, device_ids=[args.local_rank])

    # Move vae and text_encoder to the device in the training dtype.
    vae.to(device, train_dtype)
    text_encoder.to(device, train_dtype)
    # Keep vae and text_encoder in eval mode as we don't train these.
    vae.eval()
    text_encoder.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * args.world_size * args.gradient_accumulation_steps

    if is_main_process(args.global_rank):
        print("***** Running training *****")
        print(f"  Num examples = {len(train_dataset)}")
        print(f"  Num Epochs = {args.num_train_epochs}")
        print(f"  Instantaneous batch size per device = {args.train_batch_size}")
        print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        print(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not is_local_main_process(args.local_rank))
    progress_bar.set_description("Steps")
    global_step = 0

    # Prefer the safety checker and feature extractor bundled with the local model;
    # fall back to downloading them from the Hugging Face Hub.
    try:
        if is_main_process(args.global_rank):
            print("using local safety checker")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="safety_checker"
        )
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            os.path.join(args.pretrained_model_name_or_path, "feature_extractor/preprocessor_config.json")
        )
    except Exception as e:
        if is_main_process(args.global_rank):
            print("using hf download for safety checkers")
            print(e)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

    wait_for_everyone()
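    # Each step below implements the standard latent-diffusion training objective:
    # images are encoded into VAE latents (scaled by Stable Diffusion's 0.18215 factor),
    # noised via the DDPM forward process
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # and the unet is trained with an MSE loss to predict that noise from
    # (x_t, t, text embedding).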
    for epoch in range(args.num_train_epochs):
        # Let the sampler reshuffle the shards differently on every epoch.
        sampler.set_epoch(epoch)
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with torch.autocast(device_type="cuda", dtype=train_dtype):
                latents = vae.encode(batch["pixel_values"].to(device)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents.
                noise = torch.randn(latents.shape, device=latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image.
                timesteps = torch.randint(
                    0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process).
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning.
                encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
                # Reduce the per-pixel MSE to a scalar before calling backward.
                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

            # NOTE: fp16 training would normally also use a gradient scaler; bf16 does not need one.
            loss.backward()

            # Step the optimizer first, then clear gradients for the next iteration.
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Do an all-reduce of the loss across ranks for logging.
            loss = loss.detach()
            dist.all_reduce(loss, ReduceOp.SUM)
            loss = loss / args.world_size

            progress_bar.update(1)
            global_step += 1
            if is_main_process(args.global_rank):
                print(f"step {global_step} loss {loss.item()}")
            if global_step >= args.max_train_steps:
                break

    wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if is_main_process(args.global_rank):
        pipeline = StableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.module,  # unwrap the DDP container before saving
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        pipeline.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()
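# Usage sketch. The file name, dataset and exact flag spellings below are illustrative
# assumptions; the real flags are defined by the custom parse_args in utils:
#
#   torchrun --nproc_per_node=4 train_text_to_image_ddp.py \
#       --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
#       --dataset_name lambdalabs/pokemon-blip-captions \
#       --resolution 512 --train_batch_size 8 --datatype bf16 \
#       --output_dir ./sd-finetuned
#
# torchrun sets the WORLD_SIZE, LOCAL_RANK and RANK environment variables that this
# script reads at startup. The saved pipeline can then be loaded for inference with
# StableDiffusionPipeline.from_pretrained("./sd-finetuned").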