import argparse
import json
import os
import pickle
import random
import tempfile
import urllib.request

import xgboost
from smdebug import SaveConfig
from smdebug.xgboost import Hook

# File name under which ``main`` saves the trained booster; ``model_fn`` looks
# for this name first when loading a model directory.
MODEL_FILE_NAME = "xgboost-model"


class ModelLoadInferenceError(Exception):
    """Raised by ``model_fn`` when no file in the model directory can be loaded."""


def parse_args():
    """Parse training hyperparameters and smdebug/SageMaker settings from argv.

    ``--train``, ``--validation`` and ``--model-dir`` default to the
    ``SM_CHANNEL_TRAIN``, ``SM_CHANNEL_VALIDATION`` and ``SM_MODEL_DIR``
    environment variables respectively. When a variable is not set, the
    corresponding option must be given on the command line.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--eta", type=float, default=0.2)
    # gamma and min_child_weight are real-valued XGBoost parameters; float
    # still accepts every integer the former ``type=int`` accepted.
    parser.add_argument("--gamma", type=float, default=4)
    parser.add_argument("--min_child_weight", type=float, default=6)
    parser.add_argument("--subsample", type=float, default=0.8)
    parser.add_argument("--silent", type=int, default=0)
    parser.add_argument("--objective", type=str, default="binary:logistic")
    parser.add_argument("--num_round", type=int, default=50)
    parser.add_argument("--smdebug_path", type=str, default=None)
    parser.add_argument("--smdebug_frequency", type=int, default=1)
    parser.add_argument("--smdebug_collections", type=str, default="metrics")
    parser.add_argument(
        "--output_uri",
        type=str,
        default="/opt/ml/output/tensors",
        help="S3 URI of the bucket where tensor data will be stored.",
    )
    # Read the environment with .get() so a missing variable does not crash
    # parser construction; the option is only mandatory when there is no
    # environment-provided default.
    parser.add_argument(
        "--train",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN"),
        required="SM_CHANNEL_TRAIN" not in os.environ,
    )
    parser.add_argument(
        "--validation",
        type=str,
        default=os.environ.get("SM_CHANNEL_VALIDATION"),
        required="SM_CHANNEL_VALIDATION" not in os.environ,
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR"),
        required="SM_MODEL_DIR" not in os.environ,
    )
    return parser.parse_args()


def create_smdebug_hook(
    out_dir,
    train_data=None,
    validation_data=None,
    frequency=1,
    collections=None,
):
    """Build an smdebug XGBoost ``Hook`` that saves tensors every ``frequency`` steps.

    Args:
        out_dir: Destination (local path or URI) passed to the hook as ``out_dir``.
        train_data: Optional training DMatrix forwarded to the hook.
        validation_data: Optional validation DMatrix forwarded to the hook.
        frequency: Save interval passed to ``SaveConfig(save_interval=...)``.
        collections: Optional list of collection names to include.

    Returns:
        The configured ``smdebug.xgboost.Hook``.
    """
    save_config = SaveConfig(save_interval=frequency)
    hook = Hook(
        out_dir=out_dir,
        train_data=train_data,
        validation_data=validation_data,
        save_config=save_config,
        include_collections=collections,
    )
    return hook


def main():
    """Train an XGBoost model with an smdebug hook and pickle it into ``--model-dir``."""
    args = parse_args()

    train, validation = args.train, args.validation
    # URI suffix telling XGBoost the input is CSV with the label in column 0.
    parse_csv = "?format=csv&label_column=0"
    dtrain = xgboost.DMatrix(train + parse_csv)
    dval = xgboost.DMatrix(validation + parse_csv)
    watchlist = [(dtrain, "train"), (dval, "validation")]

    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "silent": args.silent,
        "objective": args.objective,
    }

    # output_uri is where smdebug saves the collected tensors/metrics;
    # --smdebug_path, when given, takes precedence over --output_uri.
    output_uri = args.smdebug_path if args.smdebug_path is not None else args.output_uri

    collections = (
        args.smdebug_collections.split(",") if args.smdebug_collections is not None else None
    )

    hook = create_smdebug_hook(
        out_dir=output_uri,
        frequency=args.smdebug_frequency,
        collections=collections,
        train_data=dtrain,
        validation_data=dval,
    )

    bst = xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        callbacks=[hook],
    )

    os.makedirs(args.model_dir, exist_ok=True)
    model_location = os.path.join(args.model_dir, MODEL_FILE_NAME)
    # Close the handle deterministically so the artifact is fully flushed.
    with open(model_location, "wb") as model_file:
        pickle.dump(bst, model_file)


def model_fn(model_dir):
    """Load a model.

    For XGBoost Framework, a default function to load a model is not provided.
    Users should provide customized model_fn() in script.

    The file named ``MODEL_FILE_NAME`` is used when present; otherwise the
    first regular file in sorted order. The file is first tried as a pickled
    object, then as a native XGBoost model file.

    Args:
        model_dir: a directory where model is saved.

    Returns:
        A tuple ``(booster, model_format)`` where ``model_format`` is
        ``"pkl_format"`` or ``"xgb_format"`` depending on which loader succeeded.

    Raises:
        ModelLoadInferenceError: if ``model_dir`` contains no regular files, or
            the chosen file cannot be loaded by either loader.
    """
    file_names = sorted(
        name for name in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, name))
    )
    if not file_names:
        raise ModelLoadInferenceError(
            "Unable to load model: no model file found in {}".format(model_dir)
        )
    model_file = MODEL_FILE_NAME if MODEL_FILE_NAME in file_names else file_names[0]
    model_path = os.path.join(model_dir, model_file)

    try:
        # NOTE(security): unpickling runs arbitrary code; only load model
        # artifacts produced by a trusted training job.
        with open(model_path, "rb") as handle:
            booster = pickle.load(handle)
        model_format = "pkl_format"
    except Exception as exp_pkl:
        try:
            booster = xgboost.Booster()
            booster.load_model(model_path)
            model_format = "xgb_format"
        except Exception as exp_xgb:
            raise ModelLoadInferenceError(
                "Unable to load model: {} {}".format(str(exp_pkl), str(exp_xgb))
            ) from exp_xgb

    booster.set_param("nthread", 1)
    return booster, model_format


if __name__ == "__main__":
    main()