import argparse
from contextlib import contextmanager

import torch


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset)."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument(
        "--validation_data_dir", type=str, default=None, help="A folder containing the validation data."
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Fraction of the training set to split off for validation.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to"
            " this resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution (if not set, use random crop).",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=-1,
        help="Total number of training steps to perform. If set to a positive value, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"].'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help=(
            "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
            " private models)."
        ),
    )
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--backend", type=str, default="nccl", help="Backend to use for distributed training (e.g. nccl, gloo)."
    )
    parser.add_argument("--datatype", type=str, default="fp16", help="Data type to use for training.")
    args, _ = parser.parse_known_args()

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None:
        raise ValueError("Need either a dataset name or a training/validation folder.")

    return args


# Distributed training helper methods.
def wait_for_everyone():
    # Blocks until all processes in the default process group reach this point.
    # Requires torch.distributed.init_process_group() to have been called.
    torch.distributed.barrier()


def is_main_process(rank):
    return rank == 0


def _goes_first(is_main):
    # Non-main processes wait at the barrier first; the main process runs the
    # body of the `with` block and then releases them by reaching the barrier.
    if not is_main:
        wait_for_everyone()
    yield
    if is_main:
        wait_for_everyone()


@contextmanager
def main_process_first(rank):
    """
    Lets the main process go first inside a with block.

    The other processes will enter the with block after the main process exits.
    """
    yield from _goes_first(is_main_process(rank))


def is_local_main_process(local_rank):
    return local_rank == 0
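
# Example usage (a minimal sketch, not part of the original script): when
# launched with `torchrun`, each process receives RANK/LOCAL_RANK/WORLD_SIZE
# environment variables, and `main_process_first` lets rank 0 download the
# dataset once before the other ranks read it from the shared cache. This
# assumes the HuggingFace `datasets` library is installed and `--dataset_name`
# points at a hub dataset; adapt as needed for local data folders.
if __name__ == "__main__":
    import os

    from datasets import load_dataset

    args = parse_args()

    # `torchrun` supplies the rendezvous details (MASTER_ADDR, MASTER_PORT,
    # WORLD_SIZE, RANK) through environment variables.
    torch.distributed.init_process_group(backend=args.backend)
    rank = int(os.environ["RANK"])

    # Rank 0 enters the block first and populates the cache; the other ranks
    # block at the barrier, then load from the already-downloaded files.
    with main_process_first(rank):
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )

    if is_main_process(rank):
        print(f"Loaded dataset with splits: {list(dataset.keys())}")

    wait_for_everyone()
    torch.distributed.destroy_process_group()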