import argparse
import os
import pickle

import xgboost

from smdebug import SaveConfig
from smdebug.xgboost import Hook


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--eta", type=float, default=0.2)
    parser.add_argument("--gamma", type=int, default=4)
    parser.add_argument("--min_child_weight", type=int, default=6)
    parser.add_argument("--subsample", type=float, default=0.8)
    parser.add_argument("--verbosity", type=int, default=0)
    parser.add_argument("--objective", type=str, default="binary:logistic")
    parser.add_argument("--num_round", type=int, default=50)
    parser.add_argument("--smdebug_path", type=str, default=None)
    parser.add_argument("--smdebug_frequency", type=int, default=1)
    parser.add_argument("--smdebug_collections", type=str, default="metrics")
    parser.add_argument(
        "--output_uri",
        type=str,
        default="/opt/ml/output/tensors",
        help="S3 URI of the bucket where tensor data will be stored.",
    )
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])

    args = parser.parse_args()
    return args


def create_smdebug_hook(out_dir, train_data=None, validation_data=None, frequency=1, collections=None):
    save_config = SaveConfig(save_interval=frequency)
    hook = Hook(
        out_dir=out_dir,
        train_data=train_data,
        validation_data=validation_data,
        save_config=save_config,
        include_collections=collections,
    )
    return hook


def main():
    args = parse_args()
    train, validation = args.train, args.validation

    # The input channels are CSV files whose first column is the label.
    parse_csv = "?format=csv&label_column=0"
    dtrain = xgboost.DMatrix(train + parse_csv)
    dval = xgboost.DMatrix(validation + parse_csv)

    watchlist = [(dtrain, "train"), (dval, "validation")]

    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "verbosity": args.verbosity,
        "objective": args.objective,
    }

    # The output_uri is the URI of the S3 bucket where the metrics will be
    # saved.
    output_uri = args.smdebug_path if args.smdebug_path is not None else args.output_uri

    collections = (
        args.smdebug_collections.split(",")
        if args.smdebug_collections is not None
        else None
    )

    hook = create_smdebug_hook(
        out_dir=output_uri,
        frequency=args.smdebug_frequency,
        collections=collections,
        train_data=dtrain,
        validation_data=dval,
    )

    bst = xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        callbacks=[hook],
    )

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    model_location = os.path.join(args.model_dir, "xgboost-model")
    with open(model_location, "wb") as f:
        pickle.dump(bst, f)


if __name__ == "__main__":
    main()
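
# A minimal sketch of reading back the tensors that the hook above writes,
# using smdebug's trial API. This helper is illustrative only and is never
# called by the script; the tensor name "train-error" is an assumption that
# depends on the objective and eval metric actually used for the run.
def inspect_saved_tensors(out_dir):
    from smdebug.trials import create_trial

    trial = create_trial(out_dir)
    # List every tensor the "metrics" collection captured.
    print(trial.tensor_names())
    for step in trial.steps():
        # .value(step) returns the value saved at that boosting round.
        print(step, trial.tensor("train-error").value(step))
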
""" model_files = (file for file in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, file))) model_file = next(model_files) try: booster = pickle.load(open(os.path.join(model_dir, model_file), 'rb')) format = 'pkl_format' except Exception as exp_pkl: try: booster = xgboost.Booster() booster.load_model(os.path.join(model_dir, model_file)) format = 'xgb_format' except Exception as exp_xgb: raise ModelLoadInferenceError("Unable to load model: {} {}".format(str(exp_pkl), str(exp_xgb))) booster.set_param('nthread', 1) return booster, format