# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import print_function

import argparse
import json
import logging
import os
import pickle as pkl
import queue
import threading

import pandas as pd
import xgboost as xgb
from sagemaker_containers import entry_point
from sagemaker_xgboost_container import checkpointing, distributed
from sagemaker_xgboost_container.data_utils import get_dmatrix

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master, checkpoint_path):
    """Run xgb train on the given arguments with rabit initialized.

    This is our rabit execution function.

    :param params: Booster hyperparameters passed through to xgb.train().
    :param dtrain: Training DMatrix.
    :param evals: List of (DMatrix, name) pairs evaluated after each boosting round.
    :param num_boost_round: Number of boosting rounds.
    :param model_dir: Directory where the master host serializes the trained model.
    :param is_master: True if the current node is the master host in distributed training,
        or is running a single node training job. Note that rabit_run will include this argument.
    :param checkpoint_path: Directory where checkpoints are written to and resumed from.
    """
    logger.info("params: {}, num_boost_round: {}, checkpoint_path: {}".format(
        params, num_boost_round, checkpoint_path))

    train_args = dict(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=evals,
    )
    booster = checkpointing.train(train_args, checkpoint_path)

    if is_master:
        # Only the master host persists the model; worker hosts exit after training.
        model_location = os.path.join(model_dir, 'xgboost-model')
        with open(model_location, 'wb') as f:
            pkl.dump(booster, f)
        logger.info("Stored trained model at {}".format(model_location))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument('--max_depth', type=int)
    parser.add_argument('--eta', type=float)
    parser.add_argument('--gamma', type=int)
    parser.add_argument('--min_child_weight', type=int)
    parser.add_argument('--subsample', type=float)
    parser.add_argument('--objective', type=str)
    parser.add_argument('--num_round', type=int)

    # SageMaker specific arguments. Defaults are set in the environment variables.
    parser.add_argument('--output_data_dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--checkpoint-path', type=str, default='/opt/ml/checkpoints',
                        help='Path where checkpoints will be saved.')
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
    parser.add_argument('--sm_hosts', type=str, default=os.environ.get('SM_HOSTS'))
    parser.add_argument('--sm_current_host', type=str, default=os.environ.get('SM_CURRENT_HOST'))

    args, _ = parser.parse_known_args()

    # Get SageMaker host information from runtime environment variables
    sm_hosts = json.loads(args.sm_hosts)
    sm_current_host = args.sm_current_host

    dtrain = get_dmatrix(args.train, 'libsvm')
    dval = get_dmatrix(args.validation, 'libsvm')
    watchlist = [(dtrain, 'train'), (dval, 'validation')] if dval is not None else [(dtrain, 'train')]

    train_hp = {
        'max_depth': args.max_depth,
        'eta': args.eta,
        'gamma': args.gamma,
        'min_child_weight': args.min_child_weight,
        'subsample': args.subsample,
        'objective': args.objective,
    }

    xgb_train_args = dict(
        params=train_hp,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        model_dir=args.model_dir,
        checkpoint_path=args.checkpoint_path,
    )

    if len(sm_hosts) > 1:
        # Wait until all hosts are able to find each other
        entry_point._wait_hostname_resolution()

        # Execute training function after initializing rabit.
        distributed.rabit_run(
            exec_fun=_xgb_train,
            args=xgb_train_args,
            include_in_training=(dtrain is not None),
            hosts=sm_hosts,
            current_host=sm_current_host,
            update_rabit_args=True,
        )
    else:
        # If single node training, call training method directly.
        if dtrain:
            xgb_train_args['is_master'] = True
            _xgb_train(**xgb_train_args)
        else:
            raise ValueError("Training channel must have data to train model.")


def model_fn(model_dir):
    """Deserialize and return fitted model.

    Note that this should have the same name as the serialized model in the _xgb_train method.
    """
    model_file = 'xgboost-model'
    with open(os.path.join(model_dir, model_file), 'rb') as f:
        booster = pkl.load(f)
    return booster
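
# ---------------------------------------------------------------------------
# Usage sketch (kept as comments so this module stays import-safe for serving).
# This is a minimal, hedged example of how a script like this is typically
# launched from a notebook or launcher script with the SageMaker Python SDK
# XGBoost framework estimator. The bucket, role, channel URIs, and the
# "abalone.py" file name are hypothetical placeholders, not values defined in
# this script.
#
#   from sagemaker.xgboost import XGBoost
#
#   estimator = XGBoost(
#       entry_point="abalone.py",                 # this training script
#       framework_version="1.0-1",
#       role=role,                                # an existing SageMaker execution role
#       instance_count=2,                         # >1 triggers the rabit_run branch above
#       instance_type="ml.m5.xlarge",
#       hyperparameters={
#           "max_depth": 5,
#           "eta": 0.2,
#           "objective": "reg:squarederror",
#           "num_round": 50,
#       },
#       checkpoint_s3_uri="s3://my-bucket/checkpoints",  # synced to /opt/ml/checkpoints
#   )
#   estimator.fit({"train": train_s3_uri, "validation": validation_s3_uri})
# ---------------------------------------------------------------------------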