#!/usr/bin/env python
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""
import os
import sys
import ast
import argparse
from collections import Counter
import numpy as np
import json
import datetime
import logging
import time
from sklearn.metrics import recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
import torch
import torch.nn as nn

from data import VoxForgeDataset, create_audio_transform, load_waveform
from models import Classifier
from utils import print_and_log, classwise_f1

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def _create_weighted_sampler(train_dataset, index_to_name, name_to_index, hparams):
    """ Creates a weighted sampler that can be passed to a torch DataLoader """
    if hparams['source_weighted_normalization']:
        logger.debug('Using source weighted and class weighted random sampling')
        source_weights = dict()
        normalized_class_ctr = dict()
        # for each language, normalize the sources so that each source contributes an equal number of samples
        for lang in train_dataset.data_df['class'].unique():
            source_ctr = dict(train_dataset.data_df[train_dataset.data_df['class'] == lang]['source'].value_counts())
            max_ct = max(source_ctr.values())
            source_weights[lang] = dict([(k, (max_ct / v)) for k, v in source_ctr.items()])
            # class_ctr normalized such that all sources, within each class, have the same amount of samples
            normalized_class_ctr[lang] = max_ct * len(source_ctr)
        normalized_class_weights = dict([(k, max(normalized_class_ctr.values()) / v)
                                         for k, v in normalized_class_ctr.items()])
        # sample weight = normalized class weight * per language source weight
        normalized_source_weights = dict()
        for lang, weights in source_weights.items():
            normalized_source_weights.update(dict([(k, normalized_class_weights[lang] * v)
                                                   for k, v in weights.items()]))
        sample_weights = [normalized_source_weights[train_dataset.data_df.iloc[i]['source']]
                          for i in range(len(train_dataset))]
    else:
        logger.debug('Using class weighted random sampling')
        class_ctr = dict(train_dataset.data_df['class'].value_counts())
        class_weights = np.array([max(class_ctr.values()) / class_ctr[index_to_name[i]]
                                  for i in range(len(class_ctr))])
        sample_weights = [class_weights[name_to_index[x]] for x in train_dataset.data_df['class'].tolist()]
    sampler = torch.utils.data.sampler.WeightedRandomSampler(sample_weights, len(sample_weights))
    return sampler


def _get_device():
    """ Get GPU device if available, otherwise CPU """
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    return device


def _get_hparams():
    """ Load user hyperparameters """
    # default hyperparameter values
    hparams = {
        'checkpoint' : None,
        'sample_rate' : 16000,
        'n_samples' : 80000,
        'languages' : ['en', 'it', 'es', 'fr', 'ru', 'de'],
        'feature_type' : 'mel',
        'resize' : None,
        'normalize' : True,
        'standardize' : True,
        'standardize_mean' : 0.4630,
        'standardize_std' : 0.2031,
        'random_pitch_shift' : True,
        'pitch_shift_range' : (-5, 5),
        'random_crop' : True,
        'frequency_masking' : True,
        'time_masking' : True,
        'source_weighted_normalization' : True,
        'source_normalized_evaluation' : True,
        'epochs' : 50,
        'batch_size' : 32,
        'lr' : 0.0001,
        'hidden_dim' : 512,
        'logging_iters' : 100,
        'early_stopping_epochs' : 5,
        's3_prefix' : None,
    }
    # load values defined in the Estimator
    with open('/opt/ml/input/config/hyperparameters.json', 'r') as fp:
        user_hparams = json.load(fp)
    # verify that only valid hyperparameters were provided by the user
    if not (user_hparams.keys() <= hparams.keys()):
        raise Exception('hyperparameters not valid. Valid hyperparameters : {}'.format(list(hparams.keys())))
    hparams.update(user_hparams)
    # SageMaker passes hyperparameters as strings; evaluate them back to Python literals where possible
    for k, v in hparams.items():
        try:
            hparams[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError, TypeError):
            continue
    return hparams


def save_model(model, model_dir, example_input):
    """ Save model using jit """
    path = os.path.join(model_dir, 'model.pt')
    traced_model = torch.jit.trace(model.cpu(), example_input)
    traced_model.save(path)


def train(args, hparams):
    """ Trains model """
    device = _get_device()
    logger.debug('Device - {}'.format(device))

    train_dataset = VoxForgeDataset(
        os.path.join(args.train, 'train_manifest.csv'),
        sample_rate=hparams['sample_rate'],
        n_samples=hparams['n_samples'],
        feature_type=hparams['feature_type'],
        resize=hparams['resize'],
        normalize=hparams['normalize'],
        standardize=hparams['standardize'],
        standardize_mean=hparams['standardize_mean'],
        standardize_std=hparams['standardize_std'],
        random_pitch_shift=hparams['random_pitch_shift'],
        pitch_range=hparams['pitch_shift_range'],
        random_crop=hparams['random_crop'],
        frequency_masking=hparams['frequency_masking'],
        time_masking=hparams['time_masking'],
        languages=hparams['languages'],
        s3_prefix=hparams['s3_prefix']
    )
    val_dataset = VoxForgeDataset(
        os.path.join(args.validation, 'val_manifest.csv'),
        sample_rate=hparams['sample_rate'],
        n_samples=hparams['n_samples'],
        feature_type=hparams['feature_type'],
        resize=hparams['resize'],
        normalize=hparams['normalize'],
        standardize=hparams['standardize'],
        standardize_mean=hparams['standardize_mean'],
        standardize_std=hparams['standardize_std'],
        include_source=True,
        languages=hparams['languages'],
        s3_prefix=hparams['s3_prefix']
    )

    # map between class names (language codes) and integer indices
    name_to_index = dict([(k, i) for i, k in enumerate(sorted(train_dataset.data_df['class'].unique()))])
    index_to_name = dict([(v, k) for k, v in name_to_index.items()])

    sampler = _create_weighted_sampler(train_dataset, index_to_name, name_to_index, hparams)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams['batch_size'], sampler=sampler)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)

    model = Classifier(n_classes=len(index_to_name), h_dim=hparams['hidden_dim']).to(device)
    logger.info('# of parameters : {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])

    with open(os.path.join(args.model_dir, 'hparams.json'), 'w') as fp:
        json.dump(hparams, fp)
    with open(os.path.join(args.model_dir, 'index_to_name.json'), 'w') as fp:
        json.dump(index_to_name, fp)

    logger.debug('Model - {}'.format(model.__class__.__name__))
    for k, v in hparams.items():
        logger.debug('{} : {}'.format(k, v))

    # Begin training loop
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    ckpt_weights_path = None
    since_best = 0
    done = False
    if hparams['checkpoint']:
        ckpt = torch.load(hparams['checkpoint'])
        best_loss = ckpt['best_loss']
        epoch = ckpt['epoch']
        itr = ckpt['itr']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        logger.info('Continuing training from: {}'.format(hparams['checkpoint']))
    else:
        best_loss = 1e10
        epoch = 0
        itr = 0

    for epoch in range(epoch, hparams['epochs']):
        if done:
            break
        model.train()
        for batch in train_dataloader:
            if done:
                break
            if torch.isnan(batch[0]).any():
                continue
            itr += 1
            features = batch[0].to(device)
            labels = torch.tensor([name_to_index[i] for i in batch[1]], dtype=torch.long, device=device)
            outputs = model(features)
            loss = loss_fn(outputs, labels).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (itr % hparams['logging_iters'] == 0) or (itr == 1):
                acc = (outputs.argmax(-1) == labels).float().mean()
                logger.info('[{}, {:5d}] loss : {:.4f}, acc : {:.4f}'.format(
                    epoch, itr, loss.item(), acc.item()))

        val_loss = evaluate(model, val_dataloader, name_to_index, index_to_name, device, args, hparams)

        # check for improved performance
        if val_loss < best_loss:
            since_best = 0
            best_loss = val_loss
            # save weights
            logger.info('Saving model (epoch = {}, itr = {})'.format(epoch, itr))
            save_model(model, args.model_dir, train_dataset[0][0].unsqueeze(0))
            # jit tracing moves the model to the CPU, so move it back before training resumes
            model.to(device)
            # save meta information
            ckpt_meta_path = os.path.join(args.model_dir, 'checkpoint')
            torch.save({
                'best_loss' : best_loss,
                'epoch' : epoch,
                'itr' : itr,
                'optimizer' : optimizer.state_dict(),
                'model' : model.state_dict()
            }, ckpt_meta_path)
        else:
            since_best += 1
            if since_best >= hparams['early_stopping_epochs']:
                done = True
                logger.info('Early stopping...')

    logger.info('Training complete!')


def evaluate(model, dataloader, name_to_index, index_to_name, device, args, hparams):
    """ Evaluate model """
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    with torch.no_grad():
        ct, i = 0, 0
        all_losses, all_outputs, all_labels, all_sources = [], [], [], []
        for batch in dataloader:
            if torch.isnan(batch[0]).any():
                continue
            i += 1
            ct += batch[0].size(0)
            features = batch[0].to(device)
            labels = torch.tensor([name_to_index[name] for name in batch[1]], dtype=torch.long, device=device)
            all_sources += batch[2]
            outputs = model(features)
            all_outputs.append(outputs.argmax(-1).cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            loss = loss_fn(outputs, labels)
            all_losses.append(loss.cpu().numpy())
    all_losses, all_outputs, all_labels = np.hstack(all_losses), np.hstack(all_outputs), np.hstack(all_labels)

    # per language f1 scores
    f1 = classwise_f1(all_outputs, all_labels, idx_to_label=index_to_name)

    metrics = pd.DataFrame()
    metrics['source'] = all_sources
    metrics['loss'] = all_losses
    metrics['predictions'] = all_outputs
    metrics['labels'] = all_labels
    if hparams['source_normalized_evaluation']:
        # average per source first so that large sources do not dominate the validation metrics
        val_loss = metrics.groupby('source').apply(lambda x : x['loss'].mean()).mean()
        val_acc = metrics.groupby('source').apply(
            lambda x : (x['predictions'] == x['labels']).astype(float).mean()).mean()
    else:
        val_loss = metrics['loss'].mean()
        val_acc = (metrics['predictions'] == metrics['labels']).astype(float).mean()

    logger.info('Val - loss : {:.4f}, acc : {:.4f}, f1 : ({})'.format(
        val_loss, val_acc, ', '.join(['{} : {:.4f}'.format(k, v) for k, v in f1.items()])))
    return val_loss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--validation', type=str, default=os.environ['SM_CHANNEL_VALIDATION'])

    hparams = _get_hparams()
    train(parser.parse_args(), hparams)
    sys.exit(0)
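
# Example launch (hedged sketch, not part of the training script): this entry point
# expects the SageMaker environment, which supplies SM_MODEL_DIR / SM_CHANNEL_* and
# writes the Estimator hyperparameters to /opt/ml/input/config/hyperparameters.json.
# A launch from a notebook might look like the snippet below; the entry point file
# name, role, bucket paths, instance type, and framework version are placeholders.
#
#   from sagemaker.pytorch import PyTorch
#
#   estimator = PyTorch(
#       entry_point='train.py',              # placeholder: this script's file name
#       role='<execution-role-arn>',
#       framework_version='1.8.1',
#       py_version='py3',
#       instance_count=1,
#       instance_type='ml.p3.2xlarge',
#       hyperparameters={'epochs': 50, 'batch_size': 32, 'lr': 0.0001},
#   )
#   estimator.fit({'train': 's3://<bucket>/train', 'validation': 's3://<bucket>/validation'})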