"""Download the diffusers unconditional-image-generation example and launch it.

The training script matching the installed ``diffusers`` version is fetched
from GitHub, a default ``accelerate`` configuration is created, and the script
is started through ``accelerate launch`` with the options given on the
command line.
"""

import os
import sys
import argparse
import logging
import shlex
import subprocess

import diffusers

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def _build_launch_command(script, args):
    """Return the ``accelerate launch`` argv list for *script*.

    Each option is emitted as a single ``--name=value`` token so that values
    containing spaces or shell metacharacters reach the training script
    unchanged; no shell is involved in interpreting them.
    """
    return [
        "accelerate",
        "launch",
        script,
        f"--dataset_name={args.dataset_name}",
        f"--resolution={args.resolution}",
        f"--output_dir={args.output_dir}",
        f"--train_batch_size={args.train_batch_size}",
        f"--num_epochs={args.num_epochs}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
    ]


def main():
    """Parse CLI options, fetch the training script and run it.

    Raises:
        subprocess.CalledProcessError: if the download, the accelerate
            configuration step, or the training run exits with a non-zero
            status. Propagating the failure makes this process exit
            non-zero instead of reporting success after a failed step.
    """
    logger.info("Training starts...")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="huggan/flowers-102-categories")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    args = parser.parse_args()

    diffusers_version = diffusers.__version__

    # download unconditional training script from diffusers
    # The git ref is derived from the installed version ("v" + version); if no
    # such ref exists the download fails and check=True aborts the run.
    branch = "v" + diffusers_version
    task = "unconditional_image_generation"
    script = "train_unconditional.py"
    url = (
        f"https://raw.githubusercontent.com/huggingface/diffusers/{branch}/examples/{task}/{script}"
    )
    # -O pins the output filename: without it a rerun in the same directory
    # would save "train_unconditional.py.1" and the stale copy would be launched.
    subprocess.run(["wget", "-O", script, url], check=True)

    # create default accelerate config
    subprocess.run(["accelerate", "config", "default"], check=True)

    # run accelerate command
    accelerate_cmd = _build_launch_command(script, args)
    # Log a shell-quoted rendering so the line can be copy-pasted to reproduce.
    logger.info("Calling %s", " ".join(shlex.quote(part) for part in accelerate_cmd))
    subprocess.run(accelerate_cmd, check=True)


if __name__ == "__main__":
    main()