{ "cells": [ { "cell_type": "markdown", "id": "a0e92bd0", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Dreambooth model fine-tuning\n", "DreamBooth is a deep generative model for fine-tuning existing text-to-image models. It was developed in 2022 by researchers at Google Research and Boston University. DreamBooth was originally built on Google's own Imagen text-to-image model, but the approach can also be applied to other text-to-image models. After training on only three to five images of a subject, the fine-tuned model can generate more detailed, personalized outputs of that subject.\n", "\n", "![](../../images/dreambooth.png)\n", "\n", "Next, we will use DreamBooth to fine-tune our Stable Diffusion model.\n", "\n", "#### Notebook steps\n", "1. Import boto3 and the SageMaker Python SDK\n", "2. Build the DreamBooth fine-tuning image\n", "3. Fine-tune the model\n", "   * Configure the hyperparameters\n", "   * Create the training job\n", "4. Test" ] }, { "cell_type": "markdown", "id": "eb9eb077", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 1. Import boto3 and the SageMaker Python SDK" ] }, { "cell_type": "code", "execution_count": null, "id": "8314fc9b-c468-497b-abcc-259ec792154c", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "from sagemaker.pytorch import PyTorch\n", "sagemaker_session = sagemaker.Session()\n", "bucket = sagemaker_session.default_bucket()\n", "role = sagemaker.get_execution_role()\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "region_name = boto3.session.Session().region_name\n", "\n", "# S3 prefixes used in this workshop; upload your subject images to images_s3uri\n", "images_s3uri = 's3://{0}/dreambooth/images/'.format(bucket)\n", "models_s3uri = 's3://{0}/stable-diffusion/models/'.format(bucket)\n", "dreambooth_s3uri = 's3://{0}/stable-diffusion/dreambooth/'.format(bucket)" ] }, { "cell_type": "markdown", "id": "bd2a3178", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2. 
Build the DreamBooth fine-tuning image\n", "If you use a smaller instance such as ml.t3.xlarge, building the image from scratch takes 60 to 90 minutes, so the workshop provides a prebuilt Docker image (see Docker.public-ecr for details). Re-run this step whenever you modify train_dreambooth.py." ] }, { "cell_type": "code", "execution_count": null, "id": "a7612e5a", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "!./build_push.sh" ] }, { "cell_type": "markdown", "id": "1d843895", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 3. Fine-tune the model\n", "\n", "* image_uri: URI of the Docker image in the ECR repository\n", "* instance_type: instance type for the training job; ml.g4dn.xlarge or ml.g5.xlarge is recommended\n", "* class_prompt: prompt describing the generic class of your subject (used for prior preservation)\n", "* instance_prompt: prompt containing the unique identifier that refers to your own images\n", "* model_name: name of the pretrained model" ] }, { "cell_type": "code", "execution_count": null, "id": "32ad6cd8-eece-43d2-b4c8-b210c63b7833", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "import json\n", "\n", "def json_encode_hyperparameters(hyperparameters):\n", "    # SageMaker passes hyperparameters to the container as strings;\n", "    # JSON-encode each value so types such as bool and float can be parsed back\n", "    for (k, v) in hyperparameters.items():\n", "        print(k, v)\n", "    return {k: json.dumps(v) for (k, v) in hyperparameters.items()}\n", "\n", "\n", "# Docker image pushed to ECR in step 2\n", "image_uri = f'{account_id}.dkr.ecr.{region_name}.amazonaws.com/sd-dreambooth-finetuning-v2'\n", "# 'local' runs the job in SageMaker local mode on this notebook instance (requires Docker and a GPU);\n", "# switch to a GPU training instance such as 'ml.g4dn.2xlarge' for a managed training job\n", "#instance_type = 'ml.g4dn.2xlarge'\n", "instance_type = 'local'\n", "\n", "instance_prompt=\"photo\\ of\\ zwx\\ man\"\n", "class_prompt=\"photo\\ of\\ a\\ man\"\n", "s3_model_output_location='s3://{}/{}/{}'.format(bucket, 'dreambooth', 'trained_models')\n", "model_name=\"runwayml/stable-diffusion-v1-5\"\n", "# Paths inside the training container: the 'images' channel is mounted at instance_dir,\n", "# and class_dir holds the prior-preservation class images\n", "instance_dir=\"/opt/ml/input/data/images/\"\n", "class_dir=\"/opt/ml/input/data/class_images/\"\n", "\n", "environment = {\n", "    # stop the CUDA caching allocator from splitting blocks larger than 32 MB to reduce fragmentation\n", "    'PYTORCH_CUDA_ALLOC_CONF':'max_split_size_mb:32',\n", "    'LD_LIBRARY_PATH':\"${LD_LIBRARY_PATH}:/opt/conda/lib/\"\n", "}\n", "\n", "hyperparameters = {\n", "    'model_name':'aws-trained-dreambooth-model',\n", "    'mixed_precision':'fp16',\n", "    'pretrained_model_name_or_path': model_name,\n", "    'instance_data_dir':instance_dir,\n", "
    'class_data_dir':class_dir,\n", "    'with_prior_preservation':True,\n", "    'models_path': '/opt/ml/model/',\n", "    'instance_prompt': instance_prompt,\n", "    'class_prompt':class_prompt,\n", "    'resolution':512,\n", "    'train_batch_size':1,\n", "    'sample_batch_size': 1,\n", "    'gradient_accumulation_steps':1,\n", "    'learning_rate':2e-06,\n", "    'lr_scheduler':'constant',\n", "    'lr_warmup_steps':0,\n", "    'num_class_images':50,\n", "    'max_train_steps':300,\n", "    'save_steps':100,\n", "    'attention':'xformers',\n", "    'prior_loss_weight': 0.5,\n", "    'use_ema':True,\n", "    'train_text_encoder':False,\n", "    'not_cache_latents':True,\n", "    'gradient_checkpointing':True,\n", "    'save_use_epochs': False,\n", "    'use_8bit_adam': False\n", "}\n", "\n", "hyperparameters = json_encode_hyperparameters(hyperparameters)" ] }, { "cell_type": "markdown", "id": "9c569c81", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Create the training job" ] }, { "cell_type": "code", "execution_count": null, "id": "744ec9cb", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "from sagemaker.estimator import Estimator\n", "\n", "# The 'images' channel is mounted at /opt/ml/input/data/images/ inside the container;\n", "# upload your three to five subject images to this S3 prefix before starting the job\n", "inputs = {\n", "    'images': images_s3uri\n", "}\n", "\n", "estimator = Estimator(\n", "    role=role,\n", "    instance_count=1,\n", "    instance_type=instance_type,\n", "    image_uri=image_uri,\n", "    hyperparameters=hyperparameters,\n", "    environment=environment,\n", "    output_path=s3_model_output_location  # S3 prefix where the trained model artifact is written\n", ")\n", "estimator.fit(inputs)" ] }, { "cell_type": "code", "execution_count": null, "id": "bd181b4e-f435-4dca-842a-444d083fdf3c", "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "dreambooth_model_data = estimator.model_data\n", "print(\"Model artifact saved at:\\n\", dreambooth_model_data)" ] }, { "cell_type": "markdown", "id": "e3e21926", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 4. 
Test\n", "You can now load your trained model with the inference notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "2696183b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Go back to 4.1 (Deploy the Stable Diffusion model) and replace andite/anything-v4.0 with your fine-tuned model:\n", "# fill in the 's3:///' placeholder below with the S3 location of that model\n", "framework_version = '1.10'\n", "py_version = 'py38'\n", "\n", "model_environment = {\n", "    'SAGEMAKER_MODEL_SERVER_TIMEOUT':'600',\n", "    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',\n", "    #'model_name':'andite/anything-v4.0',\n", "    'model_name':'s3:///',\n", "    's3_bucket':bucket\n", "}" ] } ], "metadata": { "instance_type": "ml.m5.large", "kernelspec": { "display_name": "conda_pytorch_p39", "language": "python", "name": "conda_pytorch_p39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }