# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import pickle as pkl

import pandas as pd
from sagemaker_containers import entry_point
from sagemaker_xgboost_container import distributed
from sagemaker_xgboost_container.data_utils import get_dmatrix
import xgboost as xgb

# source: https://github.com/aws/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/xgboost_abalone/xgboost_abalone_dist_script_mode.ipynb


def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master):
    """Run xgb.train() on the given arguments with rabit initialized.

    This is our rabit execution function.

    :param params: Hyperparameter dictionary passed to xgb.train().
    :param dtrain: Training DMatrix.
    :param evals: Watchlist of (DMatrix, name) pairs evaluated during training.
    :param num_boost_round: Number of boosting rounds.
    :param model_dir: Directory where the master host saves the trained model.
    :param is_master: True if the current node is the master host in distributed
        training, or is running a single-node training job. Note that rabit_run
        will include this argument.
    """
    booster = xgb.train(params=params,
                        dtrain=dtrain,
                        evals=evals,
                        num_boost_round=num_boost_round)

    if is_master:
        # Only the master host persists the model artifact.
        model_location = os.path.join(model_dir, 'xgboost-model')
        with open(model_location, 'wb') as f:
            pkl.dump(booster, f)
        logging.info("Stored trained model at {}".format(model_location))
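
# A minimal, optional smoke test for _xgb_train (not part of the original AWS
# example). It trains on a tiny synthetic dataset entirely in memory, which is
# a quick way to verify the hyperparameter plumbing and model serialization
# outside of SageMaker. The helper name and the use of numpy here are
# assumptions for illustration; nothing in this script calls this function.
def _local_smoke_test():
    import tempfile

    import numpy as np

    rng = np.random.RandomState(0)
    features = rng.rand(100, 4)
    labels = (features.sum(axis=1) > 2.0).astype(int)
    dtoy = xgb.DMatrix(features, label=labels)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # is_master=True makes _xgb_train persist the booster, exercising the
        # same code path the master host runs in a real training job.
        _xgb_train(params={'objective': 'binary:logistic', 'max_depth': 2},
                   dtrain=dtoy,
                   evals=[(dtoy, 'train')],
                   num_boost_round=5,
                   model_dir=tmp_dir,
                   is_master=True)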
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument('--max_depth', type=int)
    parser.add_argument('--eta', type=float)
    parser.add_argument('--gamma', type=int)
    parser.add_argument('--min_child_weight', type=int)
    parser.add_argument('--subsample', type=float)
    parser.add_argument('--verbosity', type=int)
    parser.add_argument('--objective', type=str)
    parser.add_argument('--num_round', type=int)
    parser.add_argument('--tree_method', type=str, default="exact")
    parser.add_argument('--predictor', type=str, default="cpu_predictor")

    # SageMaker-specific arguments. Defaults are set in the environment variables.
    parser.add_argument('--output_data_dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
    parser.add_argument('--sm_hosts', type=str, default=os.environ.get('SM_HOSTS'))
    parser.add_argument('--sm_current_host', type=str, default=os.environ.get('SM_CURRENT_HOST'))

    args, _ = parser.parse_known_args()

    # Get SageMaker host information from runtime environment variables
    sm_hosts = json.loads(args.sm_hosts)
    sm_current_host = args.sm_current_host

    dtrain = get_dmatrix(args.train, 'libsvm')
    dval = get_dmatrix(args.validation, 'libsvm')
    watchlist = [(dtrain, 'train'), (dval, 'validation')] if dval is not None else [(dtrain, 'train')]

    train_hp = {
        'max_depth': args.max_depth,
        'eta': args.eta,
        'gamma': args.gamma,
        'min_child_weight': args.min_child_weight,
        'subsample': args.subsample,
        'verbosity': args.verbosity,
        'objective': args.objective,
        'tree_method': args.tree_method,
        'predictor': args.predictor,
    }

    xgb_train_args = dict(
        params=train_hp,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        model_dir=args.model_dir)

    if len(sm_hosts) > 1:
        # Wait until all hosts are able to find each other
        entry_point._wait_hostname_resolution()

        # Execute training function after initializing rabit.
        distributed.rabit_run(
            exec_fun=_xgb_train,
            args=xgb_train_args,
            include_in_training=(dtrain is not None),
            hosts=sm_hosts,
            current_host=sm_current_host,
            update_rabit_args=True
        )
    else:
        # If single node training, call training method directly.
        if dtrain:
            xgb_train_args['is_master'] = True
            _xgb_train(**xgb_train_args)
        else:
            raise ValueError("Training channel must have data to train model.")


def model_fn(model_dir):
    """Deserialize and return fitted model.

    Note that this should have the same name as the serialized model in the
    _xgb_train method.
    """
    model_file = 'xgboost-model'
    with open(os.path.join(model_dir, model_file), 'rb') as f:
        booster = pkl.load(f)
    return booster
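
# The sketch below (not part of the original script) shows one way this entry
# point could be launched with the SageMaker Python SDK v2, as the source
# notebook does; the script name, framework version, role, instance settings,
# hyperparameter values, and S3 paths are all placeholder assumptions. With
# instance_count > 1, SageMaker starts multiple hosts and the rabit branch
# above coordinates distributed training. Nothing in this script calls this.
def _example_sdk_launch():
    from sagemaker.xgboost.estimator import XGBoost

    estimator = XGBoost(
        entry_point='xgboost_abalone_dist_script_mode.py',  # this file (assumed name)
        framework_version='1.0-1',
        hyperparameters={
            'max_depth': 5,
            'eta': 0.2,
            'gamma': 4,
            'min_child_weight': 6,
            'subsample': 0.7,
            'verbosity': 1,
            'objective': 'reg:squarederror',
            'num_round': 50,
        },
        role='<your-sagemaker-execution-role>',
        instance_count=2,  # > 1 exercises the rabit_run code path above
        instance_type='ml.m5.xlarge',
    )
    estimator.fit({
        'train': 's3://<bucket>/abalone/train',
        'validation': 's3://<bucket>/abalone/validation',
    })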