# Part One: Download data for Stable Diffusion with SageMaker Training
In this notebook you'll use job parallelism to download more than 100GB of data from the Laion-5B dataset used with Stable Diffusion. The work is broken into three parts:
1. Download the parquet files. These are sent to S3 directly from your SageMaker job.
2. Inspect the parquet files locally. You will want to upgrade your Studio instance for this.
3. Use job parallelism to run many instances in parallel, each downloading all of the image/text pairs for one parquet file.


A special note: this notebook is designed to work nicely with SageMaker Studio. You'll want to get comfortable upgrading and downgrading your instances here.

### Step 0. Update the SageMaker SDK, boto3, and botocore to enable SM Warm Pools

In [None]:
%pip install --upgrade sagemaker
%pip install boto3 --upgrade
%pip install botocore --upgrade

### Step 1. Write a script to download the parquet files

This script uses the commands suggested in [Romain's original img2dataset package](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion5B.md#normal).

In [None]:
!mkdir bootcamp_scripts

In [None]:
%%writefile bootcamp_scripts/parquet_download.py

import argparse
import os

def parse_args():

    parser = argparse.ArgumentParser()

    # SageMaker passes each hyperparameter to the script as an SM_HP_<NAME> environment variable
    parser.add_argument("--bucket", type=str, default=os.environ["SM_HP_BUCKET"])

    # argparse applies type=int to string defaults, so the environment value is converted
    parser.add_argument("--num_files", type=int, default=os.environ["SM_HP_NUM_FILES"])

    args = parser.parse_args()

    return args


def get_part_ids(num_files):

    part_ids = []

    if num_files > 127:
        print('error, currently Laion-5B only has 127 parquet files')
        return []

    for idx in range(0, num_files):
        # part IDs are zero-padded to five digits, e.g. 00000
        part_id = str(idx).zfill(5)
        part_ids.append(part_id)

    return part_ids


def download_parquet(bucket, num_files):

    part_ids = get_part_ids(num_files)

    for p_id in part_ids:

        # stream each file from Hugging Face straight into S3 without writing it to local disk
        cmd = 'wget https://huggingface.co/datasets/laion/laion2B-en-joined/resolve/main/part-{}-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet -O - | aws s3 cp - s3://{}/metadata/laion2B-en-joined/part-{}-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet'.format(p_id, bucket, p_id)

        os.system(cmd)


if __name__ == "__main__":

    args = parse_args()

    download_parquet(args.bucket, args.num_files)
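
Before launching the job, it can help to preview which files the script will request. The sketch below mirrors the part-ID and URL logic of the script in a standalone form. `preview_urls` is a hypothetical helper name, not part of the script, and the URL template is copied from the `wget` command.

```python
# Standalone preview of the URLs the download script would request.
# `preview_urls` is a hypothetical helper; the template is copied from the script's wget command.
URL_TEMPLATE = (
    "https://huggingface.co/datasets/laion/laion2B-en-joined/resolve/main/"
    "part-{}-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet"
)


def preview_urls(num_files):
    # part IDs are zero-padded to five digits, starting at 00000
    return [URL_TEMPLATE.format(str(idx).zfill(5)) for idx in range(num_files)]


urls = preview_urls(3)
for url in urls:
    print(url)
```

Nothing is downloaded here; the cell only builds the strings, so it is safe to run on a small notebook instance.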
 

### Step 2. Run on SageMaker Training

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

sess = sagemaker.Session()
role = sagemaker.get_execution_role()

bucket = sess.default_bucket()

hyperparameters = {"bucket":bucket, "num_files":10}

estimator = PyTorch(
    entry_point="parquet_download.py",
    base_job_name="sd-parquet-download",
    role=role,
    source_dir="bootcamp_scripts",
    # configures the SageMaker training resource, you can increase as you need
    instance_count=1,
    instance_type="ml.m5.large",
    py_version="py38",
    framework_version="1.10",
    sagemaker_session=sess,
    debugger_hook_config=False,
    hyperparameters=hyperparameters,
    # enable warm pools for 60 minutes, useful for debugging
    keep_alive_period_in_seconds=60 * 60,
)

In [None]:
estimator.fit(wait=False)

### Step 3. Analyze parquet response
You can check your S3 bucket to watch the parquet files come in. Once you have at least part-00000 downloaded, you can proceed to analyze it here.

In [None]:
path = 's3://{}/metadata/laion2B-en-joined/part-00000-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet'.format(bucket)

In [None]:
!mkdir parquet
!aws s3 cp {path} parquet/

In [None]:
import pandas as pd

parquet_file = 'parquet/part-00000-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet'

# please make sure you are using a larger instance for your notebook here, as the parquet file is quite large
# if your kernel dies, the instance has most likely run out of memory, so upgrade to a larger one
# I believe the smallest instance you can use here is the ml.m5.2xlarge
df = pd.read_parquet(parquet_file)

In [None]:
df.head()

### Step 4. Use job parallelism to download all of the image/text pairs

Now, to scale this out, we need to send each parquet file as an input to its own job. Each job runs on a large machine, with many jobs running at the same time, and downloads all of the images listed in its parquet file. The downloaded image/text pairs are then written back to S3.
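
As a rough sketch of that mapping, the snippet below derives the per-job settings from a single parquet file name, using the same S3 layout and job naming as the cells in this step. `plan_job` and the bucket name `my-example-bucket` are hypothetical and exist only for illustration.

```python
# Sketch: derive the per-job settings from one parquet file name.
# `plan_job` and the example bucket name are hypothetical, for illustration only.
def plan_job(bucket, p_file):
    # file names look like part-00000-<uuid>-c000.snappy.parquet
    part_number = p_file.split("-")[1]
    return {
        "job_name_prefix": "laion-part-{}".format(part_number),
        "input": "s3://{}/metadata/laion2B-en-joined/{}".format(bucket, p_file),
        "output": "s3://{}/data/part-{}/".format(bucket, part_number),
    }


example_file = "part-00003-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet"
plan = plan_job("my-example-bucket", example_file)
for key, value in plan.items():
    print(key, "->", value)
```

Because every job gets its own input file and its own output prefix, the jobs never write to the same location and can run independently.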

In [None]:
import sagemaker

sess = sagemaker.Session()

bucket = sess.default_bucket()

# this should point to the parent S3 directory with all of your parquet files
parquet_path = 's3://{}/metadata/laion2B-en-joined/'.format(bucket)

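# overwrite the listing (rather than append to it) so rerunning this cell does not duplicate entries
!aws s3 ls {parquet_path} > parquet_list.txt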

In [None]:
parquet_list = []

with open('parquet_list.txt') as f:

    for row in f.readlines():
        r = row.strip()
        # skip blank lines; the file name is the last space-separated field of each `aws s3 ls` row
        if r:
            parquet_list.append(r.split(' ')[-1])

In [None]:
# take a look at this list and make sure all the parquet files seem valid. each of these will serve as an input to its own SageMaker job
parquet_list

In [None]:
num_files = len(parquet_list)

print('About to run {} SM jobs, one per parquet file, to download the image/text pairs'.format(num_files))

In [None]:
import os

def is_open(s3_path):
    # checks to see if there is anything in the specific S3 path
    # returns True if nothing is there
    cmd = 'aws s3 ls {}'.format(s3_path)
    res = os.system(cmd)
    # on Linux os.system returns the wait status; 256 corresponds to exit code 1,
    # which this check treats as "nothing found at the path"
    return res == 256
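
If you would rather not depend on a shell exit code, the same check can be sketched with boto3. This is an alternative under stated assumptions, not what the notebook uses: `is_open_boto3` is a hypothetical name, and the client is passed in so the function can be exercised without AWS credentials.

```python
from urllib.parse import urlparse


def is_open_boto3(s3_client, s3_path):
    # split s3://bucket/prefix into its bucket and key prefix
    parsed = urlparse(s3_path)
    response = s3_client.list_objects_v2(
        Bucket=parsed.netloc,
        Prefix=parsed.path.lstrip("/"),
        MaxKeys=1,
    )
    # KeyCount is the number of keys returned; 0 means nothing is stored under the prefix
    return response.get("KeyCount", 0) == 0


# usage in the notebook (requires AWS credentials; the bucket name is made up):
# import boto3
# is_open_boto3(boto3.client("s3"), "s3://my-example-bucket/data/part-00000/")
```

Asking for at most one key keeps the call cheap even when the prefix already holds many shards.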

#### Define job parameters

In [None]:
%%writefile bootcamp_scripts/requirements.txt
img2dataset
s3fs

In [None]:
%%writefile bootcamp_scripts/download_data.py

from img2dataset import download
import shutil
import os
import multiprocessing
import threading
import argparse

def parse_args():

    parser = argparse.ArgumentParser()

    # parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])

    parser.add_argument("--cores", type=int, default=multiprocessing.cpu_count())

    parser.add_argument("--threads", type=int, default=threading.active_count())

    # local directory where SageMaker places the "parquet" input channel
    parser.add_argument("--parquet", type=str, default=os.environ["SM_CHANNEL_PARQUET"])

    parser.add_argument("--file_name", type=str, default=os.environ["SM_HP_FILE_NAME"])

    # the bucket name is taken from the third '/'-separated field of SM_MODULE_DIR
    parser.add_argument("--bucket", type=str, default=os.environ["SM_MODULE_DIR"].split('/')[2])

    args = parser.parse_args()

    return args


def prep_system():

    args = parse_args()

    # send joint path and file name
    url_list = "{}/{}".format(args.parquet, args.file_name)

    part_number = args.file_name.split('-')[1]

    # point to output path in S3
    s3_output = "s3://{}/data/part-{}/".format(args.bucket, part_number)

    return args, url_list, s3_output


if __name__ == "__main__":

    args, url_list, s3_output = prep_system()

    download(
        processes_count=args.cores,
        thread_count=args.threads,
        # takes a single parquet file
        url_list=url_list,
        image_size=256,
        # copies to S3 directly, bypassing local disk
        output_folder=s3_output,
        # each shard of image / caption pairs is written as one tar file
        output_format="webdataset",
        input_format="parquet",
        url_col="URL",
        caption_col="TEXT",
        enable_wandb=False,
        # number of image / caption pairs per shard
        number_sample_per_shard=1000,
        distributor="multiprocessing",
    )
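
The `--bucket` default in the script relies on the shape of `SM_MODULE_DIR`. The sketch below assumes that variable holds an S3 URI of the form `s3://<bucket>/<job-name>/source/sourcedir.tar.gz`. The value used here is made up, so check the real one in your job logs if the bucket comes out wrong.

```python
from urllib.parse import urlparse

# made-up example value; the real one is set by SageMaker inside the training container
module_dir = "s3://my-example-bucket/laion-part-00000-2023-01-01/source/sourcedir.tar.gz"

# what the script does: ['s3:', '', '<bucket>', ...] -> index 2 is the bucket
bucket_from_split = module_dir.split("/")[2]

# an equivalent that states the intent more explicitly
bucket_from_urlparse = urlparse(module_dir).netloc

print(bucket_from_split, bucket_from_urlparse)
```

Passing `--bucket` explicitly as a hyperparameter would avoid the dependency on this URI layout altogether.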
 

#### Loop through parquet files in S3 and run SageMaker training jobs

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

bucket = sagemaker_session.default_bucket()

In [None]:
def get_estimator(part_number, p_file, output_dir):

    # this passes the name of your parquet file as an input to the job
    hyperparameters = {"file_name": p_file}

    estimator = PyTorch(
        entry_point="download_data.py",
        base_job_name="laion-part-{}".format(part_number),
        role=role,
        source_dir="bootcamp_scripts",
        # configures the SageMaker training resource, you can increase as you need
        instance_count=1,
        instance_type="ml.c5.18xlarge",
        py_version="py36",
        framework_version="1.8",
        sagemaker_session=sagemaker_session,
        volume_size=250,
        debugger_hook_config=False,
        hyperparameters=hyperparameters,
        output_path=output_dir,
    )
    return estimator


for p_file in parquet_list:

    part_number = p_file.split('-')[1]

    output_dir = "s3://{}/data/part-{}/".format(bucket, part_number)

    # only launch a job if nothing has been written for this part yet
    if is_open(output_dir):

        est = get_estimator(part_number, p_file, output_dir)

        # wait=False returns immediately, so the jobs run in parallel
        est.fit({"parquet": "s3://{}/metadata/laion2B-en-joined/{}".format(bucket, p_file)}, wait=False)
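
Since the jobs are launched with `wait=False`, the notebook does not report on their progress. One way to keep track, sketched here with the boto3 SageMaker client's `list_training_jobs` call, is to count jobs by status. `summarize_jobs` is a hypothetical helper, and the default name filter assumes the `laion-part-` job name prefix used in this notebook.

```python
from collections import Counter


def summarize_jobs(sm_client, name_contains="laion-part"):
    # count jobs by status; follows NextToken so more than one page of results is included
    counts = Counter()
    kwargs = {"NameContains": name_contains, "MaxResults": 100}
    while True:
        response = sm_client.list_training_jobs(**kwargs)
        for summary in response["TrainingJobSummaries"]:
            counts[summary["TrainingJobStatus"]] += 1
        token = response.get("NextToken")
        if not token:
            return dict(counts)
        kwargs["NextToken"] = token


# usage in the notebook (requires AWS credentials):
# import boto3
# print(summarize_jobs(boto3.client("sagemaker")))
```

The client is passed in as an argument so the function can be tried against a stub before pointing it at a real account.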


### Conclusion and next steps
And that's a wrap! In this notebook you downloaded the metadata for the Laion-5B dataset, and then used job parallelism on SageMaker to run a full job for each parquet file.

Your next task is to configure FSx for Lustre and ensure your training script works nicely with both it and SageMaker.