{ "cells": [ { "cell_type": "markdown", "id": "450066fd", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### DreamBooth Fine-tuning\n", "DreamBooth is a deep learning technique for fine-tuning existing text-to-image models, developed by researchers from Google Research and Boston University in 2022. Although it was originally developed using Google's own Imagen text-to-image model, DreamBooth can be applied to other text-to-image models, allowing them to generate personalised outputs of a specific subject after training on only three to five images of it.\n", "\n", "In this notebook we use DreamBooth to fine-tune a Stable Diffusion model.\n", "\n", "#### Notebook steps\n", "1. Import boto3 and the SageMaker Python SDK\n", "2. Build the DreamBooth fine-tuning image\n", "3. Fine-tuning\n", " * configure hyperparameters\n", " * create the training job\n", "4. Testing " ] }, { "cell_type": "markdown", "id": "eb9eb077", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 1. Import boto3 and the SageMaker Python SDK" ] }, { "cell_type": "code", "execution_count": null, "id": "8314fc9b-c468-497b-abcc-259ec792154c", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "from sagemaker.pytorch import PyTorch\n", "sagemaker_session = sagemaker.Session()\n", "bucket = sagemaker_session.default_bucket()\n", "role = sagemaker.get_execution_role()\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "region_name = boto3.session.Session().region_name\n", "\n", "images_s3uri = 's3://{0}/dreambooth/images/'.format(bucket)\n", "models_s3uri = 's3://{0}/stable-diffusion/models/'.format(bucket)\n", "dreambooth_s3uri = 's3://{0}/stable-diffusion/dreambooth/'.format(bucket)\n", "\n", "print(bucket)" ] }, { "cell_type": "markdown", "id": "bd2a3178", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2.
Build the DreamBooth fine-tuning image\n", "Building and pushing the image takes about 60 to 90 minutes on a small notebook instance (ml.t3.xlarge)." ] }, { "cell_type": "code", "execution_count": null, "id": "a7612e5a", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!./build_push.sh" ] }, { "cell_type": "markdown", "id": "1d843895", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 3. Fine-tuning\n", "\n", " * image_uri: URI of the Docker image in Amazon ECR\n", " * instance_type: training job instance type; ml.g4dn.xlarge or ml.g5.xlarge is preferred\n", " * class_prompt: prompt describing the general class of the subject\n", " * instance_prompt: prompt containing the unique identifier for your subject\n", " * model_name: pretrained model to fine-tune\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "32ad6cd8-eece-43d2-b4c8-b210c63b7833", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "import json\n", "\n", "def json_encode_hyperparameters(hyperparameters):\n", "    # Print each hyperparameter for reference, then JSON-encode every value,\n", "    # because SageMaker passes hyperparameters to the training container as strings.\n", "    for (k, v) in hyperparameters.items():\n", "        print(k, v)\n", "\n", "    return {k: json.dumps(v) for (k, v) in hyperparameters.items()}\n", "\n", "\n", "image_uri = f'{account_id}.dkr.ecr.{region_name}.amazonaws.com/sd-dreambooth-finetuning-v2'\n", "instance_type = 'ml.g4dn.2xlarge'\n", "\n", "instance_prompt=\"photo\\ of\\ zwx\\ man\"\n", "class_prompt=\"photo\\ of\\ a\\ man\"\n", "s3_model_output_location='s3://{}/{}/{}'.format(bucket, 'dreambooth', 'trained_models')\n", "model_name=\"runwayml/stable-diffusion-v1-5\"\n", "instance_dir=\"/opt/ml/input/data/images/\"\n", "class_dir=\"/opt/ml/input/data/class_images/\"\n", "\n", "environment = {\n", " 'PYTORCH_CUDA_ALLOC_CONF':'max_split_size_mb:32',\n", " 'LD_LIBRARY_PATH':\"${LD_LIBRARY_PATH}:/opt/conda/lib/\"\n", "}\n", "\n", "hyperparameters = {\n", " 'model_name':'aws-trained-dreambooth-model',\n", " 'mixed_precision':'fp16',\n", " 'pretrained_model_name_or_path': model_name, \n", " 'instance_data_dir':instance_dir,\n", " 'class_data_dir':class_dir,\n", "
'with_prior_preservation':True,\n", " 'models_path': '/opt/ml/model/',\n", " 'instance_prompt': instance_prompt, \n", " 'class_prompt':class_prompt,\n", " 'resolution':512,\n", " 'train_batch_size':1,\n", " 'sample_batch_size': 1,\n", " 'gradient_accumulation_steps':1,\n", " 'learning_rate':2e-06,\n", " 'lr_scheduler':'constant',\n", " 'lr_warmup_steps':0,\n", " 'num_class_images':50,\n", " 'max_train_steps':300,\n", " 'save_steps':300,\n", " 'attention':'xformers',\n", " 'prior_loss_weight': 0.5,\n", " 'use_ema':True,\n", " 'train_text_encoder':False,\n", " 'not_cache_latents':True,\n", " 'gradient_checkpointing':True,\n", " 'save_use_epochs': False,\n", " 'use_8bit_adam': False\n", "}\n", "\n", "hyperparameters = json_encode_hyperparameters(hyperparameters)\n", "\n" ] }, { "cell_type": "markdown", "id": "9c569c81", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ " * Create the training job" ] }, { "cell_type": "code", "execution_count": null, "id": "744ec9cb", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from sagemaker.estimator import Estimator\n", "\n", "# The 'images' channel is made available in the container under /opt/ml/input/data/images/,\n", "# which matches instance_data_dir above.\n", "inputs = {\n", "    'images': images_s3uri\n", "}\n", "\n", "estimator = Estimator(\n", "    role=role,\n", "    instance_count=1,\n", "    instance_type=instance_type,\n", "    image_uri=image_uri,\n", "    hyperparameters=hyperparameters,\n", "    environment=environment,\n", "    # Write the trained model artifact to the S3 location defined above\n", "    output_path=s3_model_output_location\n", ")\n", "estimator.fit(inputs)" ] }, { "cell_type": "code", "execution_count": null, "id": "bd181b4e-f435-4dca-842a-444d083fdf3c", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "dreambooth_model_data = estimator.model_data\n", "print(\"Model artifact saved at:\\n\", dreambooth_model_data)" ] }, { "cell_type": "markdown", "id": "e3e21926", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 4.
Testing\n", "To test the fine-tuned model, load the model artifact printed above in the inference notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "2696183b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.m5.large", "kernelspec": { "display_name": "conda_pytorch_p39", "language": "python", "name": "conda_pytorch_p39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }