#!/usr/bin/env python3

# A sample training component that trains a simple scikit-learn decision tree model.
# This implementation works in File mode and makes no assumptions about the input file names.
# Input is specified as CSV with a data point in each row and the labels in the first column.

from __future__ import print_function

import os
import json
import pickle
import sys
import traceback

import numpy as np
import pandas as pd

# TODO: Import necessary libraries

# These are the paths to where SageMaker mounts interesting things in your container.
prefix = '/opt/ml/'

input_path = prefix + 'input/data'
output_path = os.path.join(prefix, 'output')
model_path = os.path.join(prefix, 'model')
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')

# The input channel directory is looked up as 'train' first; if that directory
# does not exist, the 'training' channel directory is used instead.
# Since we run in File mode, the input files are copied to the directory specified here.
channel_name = 'train'
training_path = os.path.join(input_path, channel_name)
if not os.path.exists(training_path):
    training_path = os.path.join(input_path, 'training')


# The function to execute the training.
def train():
    """Run the training job and report any failure.

    The actual training steps are still TODO placeholders, so at present the
    function only prints its start and completion messages and returns None.

    If an exception is raised inside the training body, its message and
    traceback are written to the ``failure`` file under ``output_path``,
    echoed to stderr, and the process exits with status 255.
    """
    print('Starting the training.')
    try:
        # TODO: Read in any hyperparameters that the user passed with the training job

        # TODO: Original source of training data, which the trainer would default to
        # if no train channel is specified

        # TODO: Load training data, create model and fit model to data

        # Save the model artifacts and character indices under /opt/ml/model

        print('Training complete.')
    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(os.path.join(output_path, 'failure'), 'w') as s:
            s.write('Exception during training: ' + str(e) + '\n' + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print('Exception during training: ' + str(e) + '\n' + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


if __name__ == '__main__':
    # train() is the only entry point defined in this file. A stray call to an
    # undefined tree() used to precede it and raised NameError before training
    # could start (and outside train()'s error reporting), so it was dropped.
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    sys.exit(0)