"""Transformer-based text classification on SageMaker with Hugging Face""" # Python Built-Ins: import argparse import logging import os import sys from typing import List, Optional # External Dependencies: import datasets #from datasets import disable_progress_bar as disable_datasets_progress_bar from transformers import ( AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, DataCollatorWithPadding, ) from sklearn.metrics import accuracy_score, precision_recall_fscore_support # Set up logging: logging.basicConfig( level=logging.getLevelName("INFO"), handlers=[logging.StreamHandler(sys.stdout)], format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) datasets.disable_progress_bar() # Too noisy on conventional log streams # Factoring your code out into smaller helper functions can help with debugging: def parse_args(): """Parse hyperparameters and data args from CLI arguments and environment variables""" parser = argparse.ArgumentParser() # hyperparameters sent by the client are passed as command-line arguments to the script. parser.add_argument("--model_id", type=str, required=True) parser.add_argument("--class_names", type=lambda s: s.split(","), required=True) parser.add_argument("--learning_rate", type=float, default=5e-5) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--epochs", type=int, default=3) parser.add_argument("--train_max_steps", type=int, default=-1) parser.add_argument("--train_batch_size", type=int, default=32) parser.add_argument("--eval_batch_size", type=int, default=64) parser.add_argument("--fp16", type=int, default=1) # Data, model, and output folders are set by combination of CLI args and env vars: parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN")) parser.add_argument("--test", type=str, default=os.environ.get("SM_CHANNEL_TEST")) parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR")) parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR")) # parser.add_argument("--n_gpus", type=int, default=os.environ.get("SM_NUM_GPUS")) args, _ = parser.parse_known_args() return args def compute_metrics(pred): labels = pred.label_ids preds = pred.predictions.argmax(-1) precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="micro") acc = accuracy_score(labels, preds) return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall} def get_model(model_id: str, class_names: List[str]) -> ( AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding ): """Set up tokenizer, model, data_collator from job parameters""" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForSequenceClassification.from_pretrained( model_id, num_labels=len(class_names) ) model.config.label2id = {name: ix for ix, name in enumerate(class_names)} model.config.id2label = {ix: name for ix, name in enumerate(class_names)} data_collator = DataCollatorWithPadding(tokenizer=tokenizer) return tokenizer, model, data_collator def load_datasets(tokenizer: AutoTokenizer, train_dir: str, test_dir: Optional[str] = None) -> ( datasets.Dataset, Optional[datasets.Dataset] ): """Load and pre-process training (+ validation?) 
dataset(s)""" def preprocess(batch): """Tokenize and pre-process raw examples for training/validation""" result = tokenizer(batch["title"], truncation=True) result["label"] = batch["category"] return result raw_train_dataset = datasets.load_dataset( "csv", data_files=[os.path.join(train_dir, f) for f in os.listdir(train_dir)], column_names=["category", "title", "content"], split=datasets.Split.ALL, ) train_dataset = raw_train_dataset.map( preprocess, batched=True, batch_size=1000, remove_columns=raw_train_dataset.column_names ) logger.info(f"Loaded train_dataset length is: {len(train_dataset)}") if test_dir: # test channel is optional: raw_test_dataset = datasets.load_dataset( "csv", data_files=[os.path.join(test_dir, f) for f in os.listdir(test_dir)], column_names=["category", "title", "content"], split=datasets.Split.ALL, ) test_dataset = raw_test_dataset.map( preprocess, batched=True, batch_size=1000, remove_columns=raw_test_dataset.column_names ) logger.info(f"Loaded test_dataset length is: {len(test_dataset)}") else: test_dataset = None logger.info("No test_dataset provided") return train_dataset, test_dataset # Only run this main block if running as a script (e.g. in training), not when imported as a module # (which would be the case if used at inference): if __name__ == "__main__": # Load job parameters: args = parse_args() training_args = TrainingArguments( max_steps=args.train_max_steps, num_train_epochs=args.epochs, per_device_train_batch_size=args.train_batch_size, per_device_eval_batch_size=args.eval_batch_size, fp16=bool(args.fp16), evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", learning_rate=args.learning_rate, warmup_steps=args.warmup_steps, disable_tqdm=True, # Interactive progress bars too noisy on conventional log streams # You could save checkpoints & logs under args.output_data_dir to upload them, but it # increases job run time by a few minutes: output_dir="/tmp/transformers/checkpoints", logging_dir="/tmp/transformers/logs", ) # Load tokenizer/model/collator: tokenizer, model, collator = get_model(model_id=args.model_id, class_names=args.class_names) # Load and pre-process the dataset: train_dataset, test_dataset = load_datasets( tokenizer=tokenizer, train_dir=args.train, test_dir=args.test, ) # Create Trainer instance trainer = Trainer( model=model, args=training_args, compute_metrics=compute_metrics, train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=tokenizer, data_collator=collator, ) # Train the model trainer.train() # Save the model output trainer.save_model(args.model_dir) # Evaluate the final model and save a report, if test dataset provided: if test_dataset: eval_result = trainer.evaluate(eval_dataset=test_dataset) # The 'output' folder will also (separately from model) get uploaded to S3 by SageMaker: if args.output_data_dir: os.makedirs(args.output_data_dir, exist_ok=True) with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer: print("***** Eval results *****") for key, value in sorted(eval_result.items()): writer.write(f"{key} = {value}\n")
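

# ------------------------------------------------------------------------------------------------
# Usage note (illustrative sketch only, not part of the training script): this file is intended to
# be launched as a SageMaker training job via the Hugging Face estimator in the SageMaker Python
# SDK. The hyperparameters map 1:1 to the CLI arguments defined in parse_args(). The entry point
# file name, bucket paths, role ARN, instance type, class names, and framework versions below are
# placeholder assumptions - substitute values valid for your own account and region.
#
#   from sagemaker.huggingface import HuggingFace
#
#   estimator = HuggingFace(
#       entry_point="train.py",             # assumed name of this script
#       source_dir=".",
#       instance_type="ml.p3.2xlarge",      # any GPU instance supported in your region
#       instance_count=1,
#       role="<your-sagemaker-execution-role-arn>",
#       transformers_version="4.26",        # example versions; use a combination supported by
#       pytorch_version="1.13",             # the SageMaker Hugging Face containers
#       py_version="py39",
#       hyperparameters={
#           "model_id": "distilbert-base-uncased",
#           "class_names": "<comma-separated-label-names>",  # parsed by the lambda in parse_args()
#           "epochs": 3,
#           "train_batch_size": 32,
#       },
#   )
#   estimator.fit({
#       "train": "s3://<your-bucket>/data/train",   # mounted as SM_CHANNEL_TRAIN
#       "test": "s3://<your-bucket>/data/test",     # mounted as SM_CHANNEL_TEST (optional)
#   })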