#!/usr/bin/env python
"""Training entry point: build the configuration, data loaders and model, then train."""
import os
import time
import copy
import datetime
import json
import tarfile
from pprint import pformat
import logging

import numpy as np

from dataset import SoundDataset
from dataloader import SoundDataLoader
from config import ParameterSetting
import pkbar
from workflow import *
from types import SimpleNamespace

# Defined after the star import so this module always has its own logger,
# whether or not `workflow` happens to export one.
logger = logging.getLogger(__name__)

# Optional JSON file whose keys override the built-in defaults below.
HYPERPARAMETERS_PATH = '/opt/ml/input/config/hyperparameters.json'
# Optional archive holding a previously trained model to resume from.
PRETRAINED_MODEL_ARCHIVE = '/opt/ml/input/data/model/model.tar.gz'
# Below this many training samples the training loader doubles as validation loader.
MIN_TRAIN_SAMPLES_FOR_SPLIT = 50


def _default_hyperparameters():
    """Return a fresh dict with the built-in default hyperparameters."""
    return {
        "csv_path": "/opt/ml/input/data/competition/meta_train.csv",
        "data_dir": "/opt/ml/input/data/competition/train",
        "save_root": "/opt/ml/result/",
        "model_file": "/opt/ml/model/final_model.pkl",
        "resume": None,
        "model_name": "VGGish",
        "val_split": 0.1,
        "epochs": 5,
        "batch_size": 128,
        "optimizer": "adam",
        "scheduler": "steplr",
        "lr": 0.0001,
        "num_class": 6,
        "normalize": None,
        "preload": False,
        "spec_aug": False,
        "time_drop_width": 64,
        "time_stripes_num": 2,
        "freq_drop_width": 8,
        "freq_stripes_num": 2,
        "sr": 8000,
        "nfft": 200,
        "hop": 80,
        "mel": 64,
    }


def _load_hyperparameters():
    """Return the defaults, overridden by HYPERPARAMETERS_PATH when that file exists."""
    hp = _default_hyperparameters()
    if os.path.exists(HYPERPARAMETERS_PATH):
        # Context manager so the handle is closed even if the JSON is malformed.
        with open(HYPERPARAMETERS_PATH, 'r') as hf:
            hp.update(json.load(hf))
    return hp


def _restore_pretrained_model(hp):
    """Unpack the pretrained-model archive, if present, and point ``hp['resume']`` at it.

    The archive is extracted into the current working directory and
    ``hp['resume']`` is set to the relative path ``'final_model.pkl'``.
    ``hp`` is modified in place.
    """
    if not os.path.exists(PRETRAINED_MODEL_ARCHIVE):
        return
    # NOTE(review): extractall() trusts member paths inside the archive; if the
    # archive can come from an untrusted source, validate members (or use an
    # extraction filter on Python versions that support it) before extracting.
    with tarfile.open(PRETRAINED_MODEL_ARCHIVE, "r:gz") as tar:
        tar.extractall()
    hp['resume'] = 'final_model.pkl'


def _build_params(args):
    """Translate the hyperparameter namespace into a ``ParameterSetting``.

    Numeric fields are cast explicitly, so values that arrive as strings from
    the JSON file are accepted as well.
    """
    # NOTE(review): resume / normalize / preload / spec_aug are passed through
    # uncast. If the JSON file supplies them as strings, "False" would be
    # truthy -- confirm how ParameterSetting interprets these values.
    return ParameterSetting(
        args.csv_path,
        args.data_dir,
        args.save_root,
        args.model_file,
        args.model_name,
        float(args.val_split),
        int(args.epochs),
        int(args.batch_size),
        float(args.lr),
        int(args.num_class),
        int(args.time_drop_width),
        int(args.time_stripes_num),
        int(args.freq_drop_width),
        int(args.freq_stripes_num),
        int(args.sr),
        int(args.nfft),
        int(args.hop),
        int(args.mel),
        args.resume,
        args.normalize,
        args.preload,
        args.spec_aug,
        args.optimizer,
        args.scheduler,
    )


def _ensure_output_dirs(save_root):
    """Create ``save_root`` and its ``snapshots`` / ``log`` sub-directories if missing."""
    if not os.path.exists(save_root):
        # makedirs also creates missing parents and tolerates a concurrent creator.
        os.makedirs(save_root, exist_ok=True)
        print("create folder: {}".format(save_root))
    for sub_dir in ('snapshots', 'log'):
        os.makedirs(os.path.join(save_root, sub_dir), exist_ok=True)


def main():
    """Run the whole training pipeline.

    Steps: load hyperparameters, optionally unpack a pretrained model to
    resume from, create the output folders, build the model and the
    train/validation data loaders, and hand everything to ``train_model``.
    """
    hp = _load_hyperparameters()
    _restore_pretrained_model(hp)

    args = SimpleNamespace(**hp)

    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(args))

    ##################
    # config setting #
    ##################
    params = _build_params(args)
    _ensure_output_dirs(params.save_root)

    ###################
    # model preparing #
    ###################
    model = prepare_model(params)

    ##################
    # data preparing #
    ##################
    print("Preparing training/validation data...")
    dataset = SoundDataset(params)

    train_dataloader = SoundDataLoader(
        dataset,
        batch_size=params.batch_size,
        shuffle=True,
        validation_split=params.val_split,
        pin_memory=True,
    )
    val_dataloader = train_dataloader.split_validation()

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}
    dataset_sizes = {
        'train': len(train_dataloader.sampler),
        'val': len(train_dataloader.valid_sampler),
    }

    # With very little data, validate on the training set itself instead of
    # on the held-out split.
    if len(train_dataloader.sampler) < MIN_TRAIN_SAMPLES_FOR_SPLIT:
        dataloaders = {'train': train_dataloader, 'val': train_dataloader}
        dataset_sizes = {
            'train': len(train_dataloader.sampler),
            'val': len(train_dataloader.sampler),
        }

    print("train size: {}, val size: {}".format(dataset_sizes['train'], dataset_sizes['val']))

    ##################
    # model training #
    ##################
    # start to train the model
    train_model(model, params, dataloaders, dataset_sizes)


if __name__ == '__main__':
    main()