FULL_TRAINING = False

import os
import sys
import json
import logging
import argparse
import torch
from torch import nn
import numpy as np
import pandas as pd
import evaluate
from tqdm import tqdm
from IPython.display import display, HTML
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from torch.utils.data import Dataset
from PIL import Image
from transformers import (
    AutoTokenizer, VisionEncoderDecoderModel, TrOCRProcessor,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, set_seed, default_data_collator
)
from sklearn.model_selection import train_test_split
from transformers.trainer_utils import get_last_checkpoint
from datasets import load_dataset

logging.basicConfig(
    level=logging.INFO,
    format='[%(filename)s:%(lineno)d] %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")


class OCRDataset(Dataset):
    def __init__(self, dataset_dir, df, processor, tokenizer, max_target_length=32):
        self.dataset_dir = dataset_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(os.path.join(self.dataset_dir, file_name)).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=self.max_target_length
        ).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # restore the PAD tokens that were masked with -100 before decoding
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}
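
# A minimal sketch (not called by this script) of the -100 masking used in
# OCRDataset and undone in compute_metrics: PAD positions are set to -100 so
# the cross-entropy loss ignores them, and must be restored before decoding.
# The tokenizer checkpoint reused here is the one from the FULL_TRAINING
# branch below; the sample string is hypothetical.
def _demo_label_masking():
    demo_tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
    labels = demo_tokenizer("안녕하세요", padding="max_length", truncation=True, max_length=16).input_ids
    masked = [l if l != demo_tokenizer.pad_token_id else -100 for l in labels]
    # undo the masking (as compute_metrics does) before decoding back to text
    restored = [demo_tokenizer.pad_token_id if l == -100 else l for l in masked]
    print(demo_tokenizer.decode(restored, skip_special_tokens=True))  # -> 안녕하세요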
def parser_args(train_notebook=False):
    parser = argparse.ArgumentParser()

    # Default Setting
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--max_length", type=int, default=64)
    parser.add_argument("--stride", type=int, default=32)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=4e-5)
    parser.add_argument("--disable_tqdm", type=bool, default=False)
    parser.add_argument("--fp16", type=bool, default=True)
    parser.add_argument("--debug", type=bool, default=False)

    # SageMaker Container environment
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--train_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints')

    if train_notebook:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args


def main():
    global tokenizer, processor  # compute_metrics() reads these at module scope

    is_sm_container = True
    if os.environ.get('SM_CURRENT_HOST') is None:
        # Not running inside a SageMaker container: emulate the SageMaker
        # environment variables so the same code path works locally.
        is_sm_container = False
        train_dir = 'train'
        model_dir = 'model'
        output_data_dir = 'data'
        src_dir = '/'.join(os.getcwd().split('/')[:-1])
        #src_dir = os.getcwd()
        os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'
        os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'
        os.environ['SM_NUM_GPUS'] = str(1)
        os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'

    args = parser_args(train_notebook=not is_sm_container)
    print(args)

    if os.environ.get('SM_CURRENT_HOST') is None:
        args.chkpt_dir = 'chkpt'

    n_gpus = torch.cuda.device_count()
    if os.getenv("SM_NUM_GPUS") is None:
        print("Explicitly specifying the number of GPUs.")
        os.environ["GPU_NUM_DEVICES"] = str(n_gpus)
    else:
        os.environ["GPU_NUM_DEVICES"] = os.environ["SM_NUM_GPUS"]

    logger.info("***** Arguments *****")
    logger.info(''.join(f'{k}={v}\n' for k, v in vars(args).items()))

    os.makedirs(args.chkpt_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.output_data_dir, exist_ok=True)

    set_seed(args.seed)

    # labels.txt has no header; split each line on the leading image file name
    # so that the file name and the transcription become separate columns.
    df = pd.read_csv(f'{args.train_dir}/labels.txt', header=None, sep=r"^(\d+\.jpg)", engine='python')
    df = df.drop(df.columns[[0]], axis=1)
    df.rename(columns={1: "file_name", 2: "text"}, inplace=True)
    df['text'] = df['text'].str.strip()
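    # Hedged illustration of the parse above (the sample row is hypothetical):
    # a labels.txt line such as
    #     3.jpg 안녕하세요
    # is split on the captured group ^(\d+\.jpg) into ['', '3.jpg', ' 안녕하세요'];
    # read_csv keeps the capture as its own column, so column 0 (empty) is
    # dropped, column 1 becomes file_name, and column 2 becomes text (stripped).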
    # Just for hands-on lab: subsample so the demo run finishes quickly
    if not FULL_TRAINING:
        df = df.sample(n=100, random_state=42)

    if FULL_TRAINING:
        vision_hf_model = 'facebook/deit-base-distilled-patch16-384'
        nlp_hf_model = "klue/roberta-base"
        # Reference: https://github.com/huggingface/transformers/issues/15823
        # Initialize the encoder from a pretrained vision model (DeiT) and the
        # decoder from a pretrained language model (KLUE RoBERTa). Note that the
        # cross-attention layers will be randomly initialized and need to be
        # fine-tuned on a downstream dataset.
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(vision_hf_model, nlp_hf_model)
        tokenizer = AutoTokenizer.from_pretrained(nlp_hf_model)
    else:
        trocr_model = 'daekeun-ml/ko-trocr-base-nsmc-news-chatbot'
        model = VisionEncoderDecoderModel.from_pretrained(trocr_model)
        tokenizer = AutoTokenizer.from_pretrained(trocr_model)

    train_df, test_df = train_test_split(df, test_size=0.1)
    train_df.reset_index(drop=True, inplace=True)
    test_df.reset_index(drop=True, inplace=True)

    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    train_dataset = OCRDataset(
        dataset_dir=args.train_dir,
        df=train_df,
        tokenizer=tokenizer,
        processor=processor,
        max_target_length=args.max_length
    )
    eval_dataset = OCRDataset(
        dataset_dir=args.train_dir,
        df=test_df,
        tokenizer=tokenizer,
        processor=processor,
        max_target_length=args.max_length
    )
    print("Number of training examples:", len(train_dataset))
    print("Number of validation examples:", len(eval_dataset))

    # set special tokens used for creating the decoder_input_ids from the labels
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    # set beam search parameters
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.max_length = args.max_length
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        num_train_epochs=args.epochs,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        output_dir=args.chkpt_dir,
        #logging_dir="./logs",
        #logging_steps=10,
        save_steps=5000,
        eval_steps=5000,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    # Save the model and tokenizer to SM_MODEL_DIR; SageMaker uploads this
    # directory to S3 when the training job finishes.
    tokenizer.save_pretrained(args.model_dir)
    trainer.save_model(output_dir=args.model_dir)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
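
# A minimal inference sketch (illustrative only; not invoked by this script).
# It assumes the artifacts saved by main() above in args.model_dir; the
# default paths below are hypothetical placeholders.
def _demo_inference(model_dir="model", image_path="sample.jpg"):
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    # preprocess the image exactly as OCRDataset does during training
    pixel_values = processor(Image.open(image_path).convert("RGB"), return_tensors="pt").pixel_values
    # beam-search generation uses the decoding parameters stored in model.config
    generated_ids = model.generate(pixel_values)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]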