#!/usr/bin/env python
"""Training entry point: fit a logistic regression on CSV data and dump it with joblib."""

import argparse
import logging
import os
from typing import Tuple

import joblib
import pandas as pd
from sklearn.linear_model import LogisticRegression

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    """
    Load training job parameters.

    SageMaker passes parameters to the training job through environment
    variables (see [1]). The argument parser here eases passing parameters
    during local debugging.

    [1] https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md

    :return: Parsed arguments exposing ``train_dir`` (training data folder)
        and ``model_dir`` (model output folder).
    """
    parser = argparse.ArgumentParser()
    # "train" is the name of training dataset channel specified to the fit() method on client side.
    parser.add_argument("--train_dir", type=str, default="/opt/ml/input/data/train")
    parser.add_argument("--model_dir", type=str, default="/opt/ml/model")
    return parser.parse_args()


def load_train_data(train_dir: str) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Load training data.

    Every regular file directly inside ``train_dir`` is read as a header-less
    CSV and the results are concatenated in sorted file-name order.

    :param train_dir: Train data dir
    :return: A tuple of (features DataFrame, label Series)
    :raises ValueError: If ``train_dir`` contains no files.
    """
    # os.listdir returns entries in arbitrary order; sort so the row order of
    # the concatenated data is reproducible between runs.
    candidates = (os.path.join(train_dir, name) for name in sorted(os.listdir(train_dir)))
    # Keep regular files only: a sub-directory cannot be parsed by read_csv.
    data_files = [path for path in candidates if os.path.isfile(path)]
    if not data_files:
        error_message = f"There are no files in {train_dir}. Please specify channel 'train' correctly."
        # No exception is being handled here, so use error() rather than
        # exception(), which would append an empty "NoneType: None" traceback.
        logger.error(error_message)
        raise ValueError(error_message)

    # assume that data files are CSV without header
    raw_data = [pd.read_csv(path, header=None) for path in data_files]
    train_data = pd.concat(raw_data)

    # assume that the first column is the label column
    label = train_data.iloc[:, 0]
    features = train_data.iloc[:, 1:]
    return features, label


def train(features: pd.DataFrame, label: pd.Series) -> LogisticRegression:
    """
    Train the model.

    :param features: features DataFrame
    :param label: label Series
    :return: The fitted estimator
    """
    estimator = LogisticRegression(solver="newton-cg")
    estimator.fit(features, label)
    return estimator


def dump_model_file(estimator, model_dir: str) -> None:
    """
    Dump the estimator to a model file.

    :param estimator: The fitted estimator
    :param model_dir: Directory to store the model file; created if missing
    :return: None
    """
    # The default model dir may not exist when running outside the training
    # container (e.g. local debugging with a custom --model_dir).
    os.makedirs(model_dir, exist_ok=True)
    # Note that the same model file name is used by `serve` to load the model.
    joblib.dump(estimator, os.path.join(model_dir, "model.joblib"))


def main() -> None:
    """Run the training job: parse arguments, load data, fit, and save the model."""
    args = parse_args()
    features, label = load_train_data(args.train_dir)
    estimator = train(features, label)
    dump_model_file(estimator, args.model_dir)


if __name__ == "__main__":
    main()