#!/opt/conda/bin/python
"""Entry point script that runs FairSeq training inside a SageMaker container.

Hyperparameters are read from the SageMaker configuration directory,
converted into command-line style arguments and handed to FairSeq.
"""

from __future__ import print_function

import copy
import json
import os
import pickle
import shutil
import sys
import traceback
from glob import glob

from fairseq import distributed_utils, options

from train_driver import main

# These are the paths to where SageMaker mounts interesting things in your
# container.
prefix = '/opt/ml/'

input_path = os.path.join(prefix, 'input/data')
output_path = os.path.join(prefix, 'output')
model_path = os.path.join(prefix, 'model')
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')

# This algorithm has a single channel of input data called 'training'. Since
# we run in File mode, the input files are copied to the directory specified
# here.
channel_name = 'training'
training_path = os.path.join(input_path, channel_name)

# Hyperparameters that are emitted as value-less switches (``--name``) when
# their value is true, instead of as ``--name value`` pairs.
_STORE_TRUE_ACTIONS = [
    'no-progress-bar',
    'fp16',
    'skip-invalid-size-inputs-valid-test',
    'fix-batches-to-gpus',
    'sentence-avg',
    'reset-optimizer',
    'reset-lr-scheduler',
    'no-save',
    'no-epoch-checkpoints',
    'cpu',
    'quiet',
    'output-word-probs',
    'output-word-stats',
    'no-early-stop',
    'unnormalized',
    'no-beamable-mm',
    'score-reference',
    'sampling',
    'print-alignment',
]


def convert_args_dict_to_list(args_dict):
    """Convert a hyperparameter mapping into a command-line argument list.

    Keys listed in ``_STORE_TRUE_ACTIONS`` are treated as switches: ``--key``
    is emitted only when the value is true (the boolean ``True`` or the
    string ``'true'`` in any letter case); otherwise nothing is emitted for
    that key. Every other key is emitted as ``--key`` followed by its value.

    Args:
        args_dict: Mapping of option name (without the leading dashes) to
            its value.

    Returns:
        A list of strings in the mapping's iteration order, suitable for
        appending to ``sys.argv``.
    """
    args = []
    for key, value in args_dict.items():
        flag = '--{}'.format(key)
        if key in _STORE_TRUE_ACTIONS:
            # Compare case-insensitively and via str() so that 'True',
            # 'true' and a JSON boolean all enable the switch.
            if str(value).lower() == 'true':
                args.append(flag)
        else:
            # Command-line arguments must be strings; coerce numbers and
            # other JSON scalars so the argument parser can consume them.
            args.extend([flag, str(value)])
    return args
def _parse_fairseq_args(training_args):
    """Parse the converted hyperparameters with the FairSeq option parser.

    The parser is invoked without explicit arguments, so the hyperparameters
    are injected through ``sys.argv`` as if they were command-line arguments.
    ``sys.argv`` is always restored afterwards, even when parsing fails.

    Args:
        training_args: List of command-line style strings built from the
            hyperparameters.

    Returns:
        The argument namespace returned by ``options.parse_args_and_arch``.

    Raises:
        ValueError: If the parser terminates with a non-zero exit status.
    """
    argv_copy = copy.deepcopy(sys.argv)
    # some arguments are pre-defined for SageMaker
    sys.argv[1:] = [training_path] + training_args + ["--save-dir", model_path]
    try:
        parser = options.get_training_parser()
        # get inference/serve params
        # options.add_dataset_args(parser, gen=True)  (left disabled)
        options.add_generation_args(parser)
        options.add_interactive_args(parser)
        return options.parse_args_and_arch(parser)
    except SystemExit as exit_request:
        # argparse-style parsers report invalid options by exiting, and
        # SystemExit is not an ``Exception``: it would skip the failure
        # handling in train(). Turn a failing exit into a regular error.
        if exit_request.code in (0, None):
            raise
        raise ValueError(
            'The training arguments were rejected by the argument parser '
            '(exit status {}).'.format(exit_request.code))
    finally:
        # remove the modifications we did in the command-line arguments
        sys.argv = argv_copy


def _copy_serving_artifacts():
    """Copy the dictionaries and BPE codes from the training data to the model dir."""
    # Dictionaries are needed by the model during serving; the bpecodes
    # file is copied only if it is available.
    for pattern in ("dict*.txt", "bpecodes"):
        for src in glob(os.path.join(training_path, pattern)):
            shutil.copy(src, model_path)


def _run_fairseq(args):
    """Run the training variant that matches the host and world-size setup."""
    resource_config_path = os.path.join(
        prefix, 'input/config/resourceconfig.json')
    with open(resource_config_path, 'r') as f:
        resource_config = json.load(f)
    hosts = resource_config['hosts']

    if len(hosts) > 1:
        # Imported lazily so single-host jobs never load this module.
        from distributed_train import main as distributed_main
        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main
        multiprocessing_main(args)
    else:
        main(args)


def train():
    """Execute the training job.

    Reads the hyperparameters from ``param_path``, converts them into
    FairSeq arguments, copies the files needed for serving into the model
    directory and starts the training. On any exception a ``failure`` file
    containing the message and traceback is written to ``output_path``, the
    same text is printed to stderr, and the process exits with status 255.
    """
    print('Starting the training.')

    try:
        # Read in any hyperparameters that the user passed with the
        # training job.
        with open(param_path, 'r') as tc:
            training_params = json.load(tc)
        print(training_params)

        training_args = convert_args_dict_to_list(training_params)
        print(training_args)

        args = _parse_fairseq_args(training_args)
        print(args)

        _copy_serving_artifacts()
        _run_fairseq(args)

        print('Training complete.')
    except Exception as e:
        trc = traceback.format_exc()
        message = 'Exception during training: ' + str(e) + '\n' + trc
        # Write out an error file. This will be returned as the
        # failureReason in the DescribeTrainingJob result.
        try:
            with open(os.path.join(output_path, 'failure'), 'w') as s:
                s.write(message)
        except (IOError, OSError) as write_error:
            # Do not let an unwritable output directory hide the real error.
            print('Could not write the failure file: ' + str(write_error),
                  file=sys.stderr)
        # Printing this causes the exception to be in the training job logs,
        # as well.
        print(message, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as
        # Failed.
        sys.exit(255)


if __name__ == '__main__':
    train()

    # A zero exit code causes the job to be marked a Succeeded.
    sys.exit(0)