from transformers import (
    Wav2Vec2ForCTC,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
)
from datasets import load_from_disk, load_metric
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import logging
import sys
import argparse
import os
import torch
import numpy as np
import boto3

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--eval_batch_size", type=int, default=8)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.005)
    parser.add_argument("--vocab_url", type=str)

    # data, model, and output directories
    parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    logger.info(f" train channel: {args.training_dir}")
    logger.info(f" test channel: {args.test_dir}")
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # download the vocabulary file, then build the tokenizer, feature extractor, and processor.
    # vocab_url is expected to look like s3://<bucket>/<prefix>, so after splitting on "/",
    # index 2 is the bucket name and index 3 is the key prefix
    s3 = boto3.client("s3")
    vocab_url = args.vocab_url
    s3.download_file(vocab_url.split("/")[2], vocab_url.split("/")[3] + "/vocab.json", "vocab.json")

    tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=False,
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    # the data collator performs dynamic padding, which speeds up batch preparation on
    # both CPU and GPU. The code is adapted from:
    # https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81
    @dataclass
    class DataCollatorCTCWithPadding:
        """
        Data collator that will dynamically pad the inputs received.
        Args:
            processor (:class:`~transformers.Wav2Vec2Processor`)
                The processor used for processing the data.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
                among:
                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
            max_length_labels (:obj:`int`, `optional`):
                Maximum length of the ``labels`` returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the
                use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
        """

        processor: Wav2Vec2Processor
        padding: Union[bool, str] = True
        max_length: Optional[int] = None
        max_length_labels: Optional[int] = None
        pad_to_multiple_of: Optional[int] = None
        pad_to_multiple_of_labels: Optional[int] = None

        def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            # split inputs and labels since they have to be of different lengths and need
            # different padding methods
            input_features = [{"input_values": feature["input_values"]} for feature in features]
            label_features = [{"input_ids": feature["labels"]} for feature in features]

            batch = self.processor.pad(
                input_features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )
            with self.processor.as_target_processor():
                labels_batch = self.processor.pad(
                    label_features,
                    padding=self.padding,
                    max_length=self.max_length_labels,
                    pad_to_multiple_of=self.pad_to_multiple_of_labels,
                    return_tensors="pt",
                )

            # replace padding with -100 so the CTC loss ignores padded label positions
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

            batch["labels"] = labels

            return batch

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
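    # Illustrative sanity check (hypothetical values, not required for training,
    # safe to delete): given two examples of different lengths, the collator
    # should return rectangular tensors, padding input_values with 0.0 and
    # masking the padded label positions with -100.
    _demo_batch = data_collator(
        [
            {"input_values": [0.1] * 400, "labels": [5, 2, 9]},
            {"input_values": [0.2] * 320, "labels": [7, 1]},
        ]
    )
    logger.info(
        f" collator demo -> input_values: {tuple(_demo_batch['input_values'].shape)},"
        f" labels: {tuple(_demo_batch['labels'].shape)}"
    )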
    # compute metrics
    wer_metric = load_metric("wer")

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # restore the pad token id so the padded label positions can be decoded
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # do not group repeated tokens when decoding the reference labels
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}
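    # How the metric behaves (hypothetical strings, safe to delete): word error
    # rate is (substitutions + deletions + insertions) / reference word count,
    # so one substituted word against a two-word reference gives 0.5.
    _demo_wer = wer_metric.compute(predictions=["hello word"], references=["hello world"])
    logger.info(f" wer demo (expect 0.5): {_demo_wer}")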
    model = Wav2Vec2ForCTC.from_pretrained(
        args.model_name,
        gradient_checkpointing=True,  # trade extra compute for lower GPU memory
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # the convolutional feature encoder is already well trained, so freeze it
    model.freeze_feature_extractor()

    # define training args
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        group_by_length=True,  # batch samples of similar length to minimize padding
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        evaluation_strategy="steps",
        num_train_epochs=args.epochs,
        fp16=True,  # enable mixed-precision training
        save_steps=50,
        eval_steps=50,
        logging_steps=50,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        save_total_limit=2,
        logging_dir=f"{args.output_data_dir}/logs",
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=processor.feature_extractor,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # write eval results to a file that can be accessed later in the S3 output
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # save the model to S3
    trainer.save_model(args.model_dir)
    tokenizer.save_pretrained(args.model_dir)
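# ----------------------------------------------------------------------------
# Usage sketch: this entry point is meant to run inside a SageMaker training
# job, where the SM_* environment variables read above are set automatically.
# A minimal launcher, assuming a notebook with the SageMaker Python SDK and
# hypothetical S3 paths/versions, could look like:
#
#     from sagemaker.huggingface import HuggingFace
#
#     estimator = HuggingFace(
#         entry_point="train.py",        # this script
#         instance_type="ml.p3.2xlarge",
#         instance_count=1,
#         role=role,                     # your SageMaker execution role
#         transformers_version="4.6",
#         pytorch_version="1.7",
#         py_version="py36",
#         hyperparameters={
#             "epochs": 10,
#             "train_batch_size": 8,
#             "vocab_url": "s3://my-bucket/my-prefix",  # hypothetical; vocab.json lives under this prefix
#         },
#     )
#     # the "train" and "test" channels populate SM_CHANNEL_TRAIN / SM_CHANNEL_TEST
#     estimator.fit({"train": "s3://my-bucket/train", "test": "s3://my-bucket/test"})
# ----------------------------------------------------------------------------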