Note: Please use the **PyTorch 2.0 Python 3.10 GPU kernel on g4dn.xlarge** for this notebook.

# What's the use case?
In this lab we will generate synthetic satellite images. These images can be used for research or as input data for building your computer vision models.

# Stable Diffusion

## Why fine-tune Stable Diffusion?

Although Stable Diffusion is good at generating images, the quality of images in a specialised domain may be poor. For example, in this notebook we aim to generate satellite images. However, the satellite images that the base model generates do not show some features (such as highways) very well. To improve the quality of satellite images with highways, we fine-tune Stable Diffusion using real satellite images.

## How do we fine-tune?

To fine-tune Stable Diffusion we use a method called DreamBooth, which is described [here](https://dreambooth.github.io/). Here is a short description of DreamBooth from the paper:
> Our method takes as input a few images (typically 3-5 images suffice, based on our experiments) of a subject (e.g., a specific dog) and the corresponding class name (e.g. "dog"), and returns a fine-tuned/"personalized'' text-to-image model that encodes a unique identifier that refers to the subject. Then, at inference, we can implant the unique identifier in different sentences to synthesize the subjects in difference contexts.

**Let's get started!**
The first step is to get a feel for the hardware. As a reminder, please make sure you have the right kernel and instance size, as specified at the top!



In [None]:
!nvidia-smi

Next, we install a few libraries that the notebook needs.

In [None]:
!pip install transformers "accelerate>=0.16.0" ftfy tensorboard Jinja2 huggingface_hub wandb kaggle git+https://github.com/huggingface/diffusers

### Dataset
For this tutorial, we use the EuroSAT dataset, a land use classification dataset consisting of Sentinel-2 satellite images. We use the `Highway` class as the type of satellite image that we would like to generate. The `Forest` and `Industrial` classes serve as the *class* from which we want the model to separate the `Highway` *instance*. Note that, for the purposes of this exercise, we display all images resized to 64×64 to match the EuroSAT image size.



In [None]:
!mkdir -p EuroSAT/Highway
!unzip -q eurosat-dataset.zip "EuroSAT/Highway/*" -d ""

In [None]:
!mkdir -p base/EuroSAT/Forest
!unzip -q eurosat-dataset.zip "EuroSAT/Forest/*" -d "base"

In [None]:
!mkdir -p base/EuroSAT/Industrial
!unzip -q eurosat-dataset.zip "EuroSAT/Industrial/*" -d "base"
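
Before going further, it can be worth checking that the extraction produced the folders the rest of the notebook expects. The helper below is a minimal sketch and not part of the original lab: `count_images` is a hypothetical name, and the folder paths are assumed from the `unzip` commands above.

```python
from pathlib import Path


def count_images(folder, suffix=".jpg"):
    """Return the number of files in `folder` ending in `suffix` (0 if the folder is missing)."""
    path = Path(folder)
    if not path.is_dir():
        return 0
    return sum(1 for p in path.iterdir() if p.is_file() and p.suffix.lower() == suffix)


# Paths assumed from the unzip cells above.
for folder in ("EuroSAT/Highway", "base/EuroSAT/Forest", "base/EuroSAT/Industrial"):
    print(f"{folder}: {count_images(folder)} images")
```

A count of 0 for any folder suggests the archive was not found or was extracted to a different location.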

## View Dataset
Let's view the `Highway` class of the EuroSAT dataset

In [None]:
from PIL import Image

def image_grid(imgs, rows, cols):
    """Paste `imgs` into a rows x cols grid, filling row by row."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Column index is i % cols, row index is i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

In [None]:
actual_img = [Image.open("EuroSAT/Highway/Highway_{}.jpg".format(str(i))) for i in range(1,11)]
image_grid([x.resize((64,64)) for x in actual_img], 2,5)

Let's view the `Forest` and `Industrial` classes:

In [None]:
actual_img = [Image.open("base/EuroSAT/Forest/Forest_{}.jpg".format(str(i))) for i in range(1,11)]
image_grid([x.resize((64,64)) for x in actual_img], 2,5)

In [None]:
actual_img = [Image.open("base/EuroSAT/Industrial/Industrial_{}.jpg".format(str(i))) for i in range(1,11)]
image_grid([x.resize((64,64)) for x in actual_img], 2,5)

In [None]:
import shutil, os
forest_files = os.listdir("base/EuroSAT/Forest")
industrial_files = os.listdir("base/EuroSAT/Industrial")

In [None]:
!mkdir -p "base/class"

Next, some preparation: copy the `Forest` and `Industrial` images into a single `base/class` directory that the fine-tuning script can use as class images.

In [None]:
# Copy every Forest image into the shared class directory.
for filename in forest_files:
    shutil.copyfile(
        os.path.join("base/EuroSAT/Forest", filename),
        os.path.join("base/class", filename)
    )
# Copy every Industrial image into the same directory.
for filename in industrial_files:
    shutil.copyfile(
        os.path.join("base/EuroSAT/Industrial", filename),
        os.path.join("base/class", filename)
    )
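
Copying two folders into one silently overwrites files if any names collide. The file names used in this notebook carry the class name as a prefix (for example `Forest_1.jpg`), so that should not happen here, but a defensive variant is cheap. The following is a sketch under that assumption; `merge_folders` is a hypothetical helper, not part of the lab.

```python
import shutil
from pathlib import Path


def merge_folders(sources, destination):
    """Copy every file from each source folder into `destination`.

    Raises FileExistsError instead of overwriting when a file name is
    already present in the destination. Returns the number of files copied.
    """
    dest = Path(destination)
    dest.mkdir(parents=True, exist_ok=True)
    copied = 0
    for source in sources:
        for src_file in sorted(Path(source).iterdir()):
            if not src_file.is_file():
                continue
            target = dest / src_file.name
            if target.exists():
                raise FileExistsError(f"{target} already exists")
            shutil.copyfile(src_file, target)
            copied += 1
    return copied
```

For this notebook the call would be `merge_folders(["base/EuroSAT/Forest", "base/EuroSAT/Industrial"], "base/class")`. Because the helper refuses to overwrite, running it a second time against an already populated `base/class` raises an error rather than copying again.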

## Images generated by Stable Diffusion
Before we start fine-tuning, let's have a look at the default images generated by Stable Diffusion. We use Stable Diffusion (1.5) to generate satellite images of the `Highway` class.

We leverage the [Diffusers](https://huggingface.co/docs/diffusers/index) library from Hugging Face for the generation.

In [None]:
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

In [None]:
img_list = pipe(["Sentinel 2 satellite image of a highway"]*10, num_inference_steps=25).images

In [None]:
image_grid([x.resize((64,64)) for x in img_list], 2,5)

In [None]:
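import gc

# Release the pipeline and free cached GPU memory before fine-tuning.
del pipe
gc.collect()
torch.cuda.empty_cache()

# Optional hard reset of the GPU (requires numba):
# from numba import cuda
# device = cuda.get_current_device()
# device.reset()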

## Actual highway class images from EuroSAT 

In [None]:
actual_img = [Image.open("EuroSAT/Highway/Highway_{}.jpg".format(str(i))) for i in range(1,11)]
image_grid([x.resize((64,64)) for x in actual_img], 2,5)

We see that, in terms of color and style, there is a significant difference between the images generated directly by Stable Diffusion and the actual EuroSAT images.

## Fine-tune Stable Diffusion with LoRA and DreamBooth
We want to fine-tune our text-to-image model so that it learns to generate the right type of satellite images. To do so, we use two recent techniques, DreamBooth and LoRA. DreamBooth lets a model learn to generate images that fit the distinct characteristics of the `instance` relative to the larger `class`. Low-rank adaptation (LoRA) speeds up training by drastically reducing the number of trainable parameters. We use the scripts found [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README.md).
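
To get a feel for why LoRA reduces the number of trainable parameters: instead of updating a full d × k weight matrix, LoRA trains two small matrices of shapes d × r and r × k, and their product is added to the frozen weight. The sketch below only does the arithmetic. The layer size (768 × 768) and rank (4) are illustrative assumptions, not values read from the training script, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d, k, r):
    """Return (full, lora) trainable parameter counts for one d x k weight matrix.

    full: parameters updated when the whole matrix is trained.
    lora: parameters in the two low-rank factors (d x r and r x k).
    """
    full = d * k
    lora = r * (d + k)
    return full, lora


# Illustrative numbers only: one 768 x 768 projection with rank 4.
full, lora = lora_param_counts(768, 768, 4)
print(f"full fine-tuning: {full} parameters")
print(f"LoRA (r=4):       {lora} parameters")
print(f"reduction factor: {full / lora:.0f}x")
```

Under these assumed sizes, the low-rank factors hold 6,144 parameters against 589,824 for the full matrix.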

To enable Stable Diffusion to learn a new `instance`, we use a unique (and short) token to represent it. In our case, we use the token `sks`, which is commonly used in Stable Diffusion fine-tuning tutorials and is not close, in terms of character sequence, to other meaningful words.
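
The prompts used in this notebook follow directly from this idea: the instance prompt is the class prompt with the identifier appended. A small sketch of that relationship, where `build_prompts` is a hypothetical helper and the prompt strings are the ones passed to the training script in this notebook:

```python
def build_prompts(class_prompt, identifier):
    """Return (instance_prompt, class_prompt) for DreamBooth-style training.

    The instance prompt extends the class prompt with the unique identifier,
    so the model can tie the identifier to the new instance.
    """
    instance_prompt = f"{class_prompt} of {identifier}"
    return instance_prompt, class_prompt


instance_prompt, class_prompt = build_prompts("Sentinel 2 satellite image", "sks")
print(instance_prompt)  # Sentinel 2 satellite image of sks
print(class_prompt)     # Sentinel 2 satellite image
```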

We first download the DreamBooth LoRA training script from the diffusers repository.

In [None]:
!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora.py

Next, we run the fine-tuning code. This runs fine-tuning locally on the notebook instance. The [accelerate](https://github.com/huggingface/accelerate) library launches the script and makes it easy to run the same PyTorch code on one or more GPUs.

In [None]:
!accelerate launch train_dreambooth_lora.py \
 --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
 --instance_data_dir="EuroSAT/Highway" \
 --output_dir=trained_model \
 --instance_prompt="Sentinel 2 satellite image of sks" \
 --resolution=256 \
 --train_batch_size=1 \
 --gradient_accumulation_steps=1 \
 --checkpointing_steps=100 \
 --learning_rate=1e-4 \
 --report_to="tensorboard" \
 --lr_scheduler="constant" \
 --lr_warmup_steps=0 \
 --with_prior_preservation \
 --class_data_dir="base/class" \
 --class_prompt="Sentinel 2 satellite image" \
 --max_train_steps=800 \
 --seed="0" 
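
The flags above determine how many training samples the model sees and which checkpoints get written. Here is a back-of-the-envelope sketch, assuming the script writes a `checkpoint-<step>` folder every `checkpointing_steps` optimizer steps (which is what the `checkpoint-800` path used later in this notebook suggests). `training_schedule` is a hypothetical helper, not part of the script.

```python
def training_schedule(max_train_steps, train_batch_size,
                      gradient_accumulation_steps, checkpointing_steps):
    """Return (effective_batch, samples_seen, checkpoint_names) for a run."""
    # Samples contributing to each optimizer step.
    effective_batch = train_batch_size * gradient_accumulation_steps
    # Total instance samples drawn over the whole run.
    samples_seen = max_train_steps * effective_batch
    # Assumed naming: one folder per multiple of checkpointing_steps.
    checkpoints = [
        f"checkpoint-{step}"
        for step in range(checkpointing_steps, max_train_steps + 1, checkpointing_steps)
    ]
    return effective_batch, samples_seen, checkpoints


effective_batch, samples_seen, checkpoints = training_schedule(
    max_train_steps=800,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    checkpointing_steps=100,
)
print(effective_batch, samples_seen)
print(checkpoints[-1])
```

With the values used in this notebook, the run draws 800 instance samples in total, and the last checkpoint name coincides with the final training step.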

## Visualizing results
Now that the model is trained, let's compare:
1. Stable Diffusion generated images without fine-tuning
2. Stable Diffusion generated images with LoRA and Dreambooth fine-tuning
3. Original EuroSAT images

In [None]:
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

Let's look at the images generated without fine-tuning.

In [None]:
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

In [None]:
img_list = pipe(["Sentinel 2 satellite image of a highway"]*3, num_inference_steps=25).images
image_grid([x.resize((128,128)) for x in img_list], 1,3)

Next, we look at the images created after fine-tuning

In [None]:
pipe.unet.load_attn_procs("./trained_model/checkpoint-800")

In [None]:
img_list = pipe(["Sentinel 2 satellite image of sks"]*3, num_inference_steps=25).images

In [None]:
image_grid([x.resize((128,128)) for x in img_list], 1,3)

Finally, we look at the original images.

In [None]:
from PIL.ImageOps import exif_transpose
actual_img = [exif_transpose(Image.open("EuroSAT/Highway/Highway_{}.jpg".format(str(i)))) for i in range(1,4)]
image_grid([x.resize((128,128)) for x in actual_img], 1,3)

That's it! This finishes the notebook. We have seen how fine-tuning Stable Diffusion with custom images brings the generated images closer to the real EuroSAT images.

## Cleanup
After you close the notebook, please make sure that you also shut down the instance using the icon (black square within white circle) on the left.