import argparse
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset)."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument(
        "--validation_data_dir", type=str, default=None, help="A folder containing the validation data."
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Fraction of the train set to split off for validation.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution (if not set, use random crop).",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( "--max_train_steps", type=int, default=-1, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=True, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument( "--use_auth_token", action="store_true", help=( "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" " private models)." ), ) parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose" "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." "and an Nvidia Ampere GPU." 
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    # When launched via MPI (e.g. mpirun), derive the torch.distributed environment
    # variables from the OpenMPI ones. Only touch RANK/WORLD_SIZE inside this guard,
    # since the OMPI variables are unset for non-MPI launches.
    if int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", -1)) >= 0:
        env_local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        os.environ["LOCAL_RANK"] = str(env_local_rank)
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None:
        raise ValueError("Need either a dataset name or a training/validation folder.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def freeze_params(params):
    for param in params:
        param.requires_grad = False


dataset_name_mapping = {
    "image_caption_dataset.py": ("image_path", "caption"),
}


def _mp_fn(index):
    # Entry point for TPU training via torch_xla's xla_spawn.
    main()


def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        use_auth_token=args.use_auth_token,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
    )

    # Freeze vae and text_encoder
    freeze_params(vae.parameters())
    freeze_params(text_encoder.parameters())

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if (args.dataset_name is not None) and ("." not in args.dataset_name):
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            use_auth_token=True if args.use_auth_token else None,
        )
    elif args.dataset_name is not None:
        # A dataset name containing a "." is treated as a path to local parquet file(s).
        dataset = load_dataset("parquet", data_files=args.dataset_name)
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        if args.validation_data_dir is not None:
            data_files["validation"] = os.path.join(args.validation_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_process#imagefolder.

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        input_ids = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding=True, truncation=True
        ).input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    val_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        # The image column is expected to contain file paths that PIL can open.
        images = [Image.open(image).convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def preprocess_val(examples):
        images = [Image.open(image).convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [val_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples, is_train=False)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)
        if args.max_eval_samples is not None:
            dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
        # Set the validation transforms
        eval_dataset = dataset["validation"].with_transform(preprocess_val)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        input_ids = [example["input_ids"] for example in examples]
        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        return {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
            "attention_mask": padded_tokens.attention_mask,
        }

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=4
    )
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, num_workers=4
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps <= 0:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Move the frozen vae and text_encoder to device
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    # Keep vae and text_encoder in eval mode as we don't train them
    vae.eval()
    text_encoder.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
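    # (accelerator.prepare shards the dataloader across processes, so each process now
    # sees roughly len(train_dataset) / num_processes batches per epoch.)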
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    if accelerator.is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {args.num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    # Prefer the safety checker and feature extractor bundled with the local model;
    # fall back to downloading the defaults from the Hub.
    try:
        if accelerator.is_main_process:
            logger.info("using local safety checker")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="safety_checker"
        )
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            os.path.join(args.pretrained_model_name_or_path, "feature_extractor/preprocessor_config.json")
        )
    except Exception as err:
        if accelerator.is_main_process:
            logger.info("using hf download for safety checkers")
            logger.info(err)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
    accelerator.wait_for_everyone()

    for epoch in range(args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual and compute loss
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
                loss = F.mse_loss(noise_pred, noise, reduction="none")
                loss = loss.mean([1, 2, 3]).mean()

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                # Advance global_step once per process so it tracks optimizer steps across all ranks.
                for _ in range(accelerator.num_processes):
                    progress_bar.update(1)
                    global_step += 1

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
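            # Note: the loss/lr values here come from the local process only;
            # accelerator.log below is a no-op on non-main processes.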
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline(
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=vae,
            unet=accelerator.unwrap_model(unet),
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
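
# ---------------------------------------------------------------------------
# Example usage (illustrative only; the script filename, model, and dataset
# names below are common public choices, not ones prescribed by this script):
#
#   accelerate launch train_text_to_image.py \
#     --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
#     --dataset_name lambdalabs/pokemon-blip-captions \
#     --resolution 512 --center_crop \
#     --train_batch_size 1 --gradient_accumulation_steps 4 \
#     --mixed_precision fp16 \
#     --max_train_steps 15000 \
#     --output_dir sd-model-finetuned
#
# After training, the saved pipeline can be loaded for inference, e.g.:
#
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("sd-model-finetuned")
#   image = pipe("a photo of an astronaut riding a horse").images[0]
#   image.save("out.png")
# ---------------------------------------------------------------------------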