import argparse
import datetime
import os
from functools import partial
from pathlib import Path

import datasets
import evaluate
import nltk
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import LoggerType
from datasets import concatenate_datasets
from nltk.tokenize import sent_tokenize
from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="google/flan-t5-large",
        # required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--train_dataset_path",
        type=str,
        default="/opt/ml/input/data/train",
        # required=True,
        help="Path to the dataset.",
    )
    parser.add_argument("--lr", type=float, default=3e-3, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
    parser.add_argument("--seed", type=int, default=42, help="Seed.")
    parser.add_argument(
        "--subsample", type=int, default=25, help="Percentage of training data to use."
    )
    parser.add_argument(
        "--model_dir", type=str, default="/opt/ml/model", help="Model dir."
    )
    parser.add_argument(
        "--tensorboard_dir",
        type=str,
        default="/opt/ml/output/tensorboard",
        help="Tensorboard dir.",
    )
    parser.add_argument("--log_steps", type=int, default=10, help="Log interval steps.")
    args = parser.parse_args()
    return args


def preprocess_function(
    sample,
    tokenizer,
    max_source_length,
    max_target_length,
    padding="max_length",
):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]

    # tokenize inputs
    model_inputs = tokenizer(
        inputs, max_length=max_source_length, padding=padding, truncation=True
    )

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(
        text_target=sample["summary"],
        max_length=max_target_length,
        padding=padding,
        truncation=True,
    )

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
    # when we want to ignore padding in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label]
            for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def collate_fn(examples, tokenizer):
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels
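# compute_metrics decodes the generated token ids and the reference labels, scores them
# with ROUGE (with stemming), and reports the mean generated length; it is only called
# on the main process after predictions have been gathered from all ranks.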
def compute_metrics(preds, labels, tokenizer):
    metric = evaluate.load("rouge")
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result


def main(args):
    model_name_or_path = args.pretrained_model_name_or_path
    dataset_path = Path(args.train_dataset_path)
    lr = args.lr
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    seed = args.seed
    tb_log_dir = args.tensorboard_dir
    tb_log_interval = args.log_steps

    accelerator = Accelerator(log_with=LoggerType.TENSORBOARD, project_dir=tb_log_dir)
    accelerator.init_trackers(".", init_kwargs={"tensorboard": {"flush_secs": 30}})

    config = {
        "lr": lr,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seed": seed,
    }
    if accelerator.is_main_process:
        # workaround for hparams not showing up in tensorboard if no metrics are logged
        # https://github.com/tensorflow/tensorboard/issues/2942
        tb_tracker = [
            tracker for tracker in accelerator.trackers if tracker.name == "tensorboard"
        ][0]
        # log hp_metric till the issue in TB is fixed
        tb_tracker.writer.add_hparams(config, {"hp_metric": 0}, run_name=".")
        tb_tracker.writer.flush()

    with accelerator.main_process_first():
        # configure evaluation metrics
        # this should run in the main process first to download the punkt corpus
        nltk.download("punkt")

    set_seed(seed)

    # read the dataset
    ds = datasets.Dataset.from_json((dataset_path / "dialogsum.train.jsonl").as_posix())

    # take a subsample of the data
    ds = datasets.Dataset.from_pandas(
        ds.to_pandas().sample(
            frac=args.subsample / 100, random_state=seed, ignore_index=True
        )
    )

    # split into train and test
    dataset = ds.train_test_split(test_size=0.1)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(
        lambda x: tokenizer(x["dialogue"], truncation=True),
        batched=True,
        remove_columns=["dialogue", "summary", "fname", "topic"],
    )
    max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])

    tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(
        lambda x: tokenizer(x["summary"], truncation=True),
        batched=True,
        remove_columns=["dialogue", "summary", "fname", "topic"],
    )
    max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])

    preprocess = partial(
        preprocess_function,
        tokenizer=tokenizer,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
    )

    # with accelerator.main_process_first():
    processed_datasets = dataset.map(
        preprocess,
        batched=True,
        num_proc=1,
        load_from_cache_file=True,
        remove_columns=["dialogue", "summary", "fname", "topic"],
        desc="Running tokenizer on dataset",
    )
    accelerator.wait_for_everyone()

    train_dataset = processed_datasets["train"]
    test_dataset = processed_datasets["test"]

    collate = partial(collate_fn, tokenizer=tokenizer)

    train_dataloader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=True,
        collate_fn=collate,
        batch_size=batch_size,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_dataset, collate_fn=collate, batch_size=batch_size * 8, pin_memory=True
    )

    # create the model
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
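    # Wrap the base model with LoRA: get_peft_model freezes the pretrained weights and
    # injects trainable low-rank adapter matrices, so only a small fraction of the
    # parameters is updated during fine-tuning.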
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
        r=8,  # rank of the LoRA update matrices (LoRA attention dimension)
        lora_alpha=32,  # the LoRA update is scaled by lora_alpha / r (similar in effect to tuning the learning rate)
        lora_dropout=0.1,  # dropout rate applied to the LoRA layers
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # create the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # create an lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

    # prepare model for training
    (
        model,
        train_dataloader,
        test_dataloader,
        optimizer,
        lr_scheduler,
    ) = accelerator.prepare(
        model,
        train_dataloader,
        test_dataloader,
        optimizer,
        lr_scheduler,
    )

    is_ds_zero_3 = False
    if getattr(accelerator.state, "deepspeed_plugin", None):
        is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3

    total_steps = 0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for step, batch in enumerate(tqdm(train_dataloader)):
            # gradient accumulation
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            train_loss = total_loss / (step + 1)
            train_perplexity = torch.exp(train_loss)

            # log to tensorboard
            if step % tb_log_interval == 0:
                accelerator.log(
                    {
                        "training_loss": train_loss.item(),
                        "train_perplexity": train_perplexity.item(),
                    },
                    step=total_steps,
                )
            total_steps += 1

        train_epoch_loss = total_loss / len(train_dataloader)
        train_epoch_perplexity = torch.exp(train_epoch_loss)
        accelerator.print(f"{epoch=}: {train_epoch_perplexity=} {train_epoch_loss=}")

        model.eval()
        eval_preds = []
        eval_labels = []
        max_new_eval_tokens = 100
        for _, batch in enumerate(tqdm(test_dataloader)):
            labels = batch.pop("labels")
            with torch.no_grad():
                outputs = accelerator.unwrap_model(model).generate(
                    **batch,
                    synced_gpus=is_ds_zero_3,  # synced_gpus must be True for DeepSpeed ZeRO stage 3
                    max_new_tokens=max_new_eval_tokens,
                )
            # pad generated ids to a fixed width so they can be gathered across processes
            # and stacked later; encoder-decoder outputs include the decoder start token,
            # so sequences can be up to max_new_eval_tokens + 1 tokens long
            outputs = torch.nn.functional.pad(
                outputs,
                (0, max_new_eval_tokens + 1 - outputs.shape[1]),
                "constant",
                tokenizer.pad_token_id,
            )
            preds = accelerator.gather(outputs).detach().cpu().numpy()
            labels = accelerator.gather(labels).detach().cpu().numpy()
            eval_preds.extend(preds)
            eval_labels.extend(labels)

        if accelerator.is_main_process:
            eval_preds = np.stack(eval_preds)
            eval_labels = np.stack(eval_labels)
            metrics = compute_metrics(eval_preds, eval_labels, tokenizer)
            accelerator.print(metrics)
            accelerator.log(metrics, step=total_steps)

    accelerator.wait_for_everyone()

    peft_model_id = (
        f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
    )
    checkpoint_name = f"{args.model_dir}/{peft_model_id}/adapter_model.bin"

    if accelerator.is_main_process:
        model.save_pretrained(f"{args.model_dir}/{peft_model_id}")
        accelerator.save(
            get_peft_model_state_dict(
                model, state_dict=accelerator.get_state_dict(model)
            ),
            checkpoint_name,
        )
    accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
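# Example invocation (a sketch: assumes this script is saved as train.py and that the
# DialogSum JSONL file lives under ./data; adjust paths and hyperparameters to your setup):
#
#   accelerate launch train.py \
#       --pretrained_model_name_or_path google/flan-t5-large \
#       --train_dataset_path ./data \
#       --model_dir ./model \
#       --tensorboard_dir ./tb_logs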