import json
import os
import pickle as pkl
import time
from io import BytesIO

import numpy as np
import sagemaker_xgboost_container.encoder as xgb_encoders
import xgboost as xgb

# Number of feature columns per record in an "application/x-npy" payload.
NUM_FEATURES = 58


def model_fn(model_dir):
    """
    Deserialize and return the fitted model.

    Loads the file named "xgboost-model" found in `model_dir` into a fresh
    `xgb.Booster` and returns that booster.
    """
    model_file = "xgboost-model"
    model = xgb.Booster()
    model.load_model(os.path.join(model_dir, model_file))
    return model


def input_fn(request_body, request_content_type):
    """
    Convert a request payload into a DMatrix that `predict_fn` can consume.

    Supported content types:

    * "application/x-npy": the body is read with `np.frombuffer` as a flat
      buffer of float64 values and reshaped to (rows, NUM_FEATURES).
      NOTE(review): this is a raw buffer, not an `np.save` file with a
      header -- confirm that clients send `array.tobytes()`.
    * "text/csv": parsed by `xgb_encoders.csv_to_dmatrix` after trailing
      newlines are stripped.
    * "text/libsvm": parsed by `xgb_encoders.libsvm_to_dmatrix`.

    Raises:
        ValueError: if the content type is not supported, or if an npy
            payload does not contain a whole number of NUM_FEATURES-wide rows.
    """
    print("Content type: ", request_content_type)

    if request_content_type == "application/x-npy":
        values = np.frombuffer(request_body, dtype=np.float64)
        # Fail with a clear message instead of a cryptic reshape error when
        # the payload cannot be split into complete records.
        if values.size % NUM_FEATURES:
            raise ValueError(
                "Payload holds {} values, which is not a multiple of "
                "NUM_FEATURES ({}).".format(values.size, NUM_FEATURES)
            )
        return xgb.DMatrix(values.reshape(-1, NUM_FEATURES))

    if request_content_type == "text/csv":
        # str.rstrip("\n") below raises TypeError on bytes, so decode first.
        if isinstance(request_body, (bytes, bytearray)):
            request_body = request_body.decode("utf-8")
        return xgb_encoders.csv_to_dmatrix(request_body.rstrip("\n"))

    if request_content_type == "text/libsvm":
        return xgb_encoders.libsvm_to_dmatrix(request_body)

    raise ValueError(
        "Content type {} is not supported.".format(request_content_type)
    )


def predict_fn(input_data, model):
    """
    Run the model on the output of `input_fn`.

    Returns a two-dimensional NumPy array: row 0 holds the predicted labels
    (1 where the score is >= 0.5, otherwise 0) and row 1 holds the raw scores
    returned by `model.predict`. The elapsed prediction time is printed.
    """
    start_time = time.time()
    y_probs = model.predict(input_data)
    print("--- Inference time: %s secs ---" % (time.time() - start_time))

    # Vectorized thresholding; np.vstack promotes the labels to float
    # alongside the scores.
    y_preds = (np.asarray(y_probs) >= 0.5).astype(int)
    return np.vstack((y_preds, y_probs))


def output_fn(predictions, content_type="application/json"):
    """
    Serialize the output of `predict_fn` for the response.

    `predictions` is indexed as a 2 x N array: row 0 is the predicted labels,
    row 1 the scores.

    * "text/csv": one line per record formatted as "pred,prob", lines joined
      with "\\n".
    * "application/json": a JSON object with keys "pred" and "prob", each a
      list with one entry per record.

    Raises:
        ValueError: if `content_type` is not supported.
    """
    if content_type == "text/csv":
        # Transpose so each row is one record: [pred, prob].
        records = np.asarray(predictions).T.tolist()
        return "\n".join(
            ",".join(str(value) for value in record) for record in records
        )

    if content_type == "application/json":
        return json.dumps({
            'pred': predictions[0, :].tolist(),
            'prob': predictions[1, :].tolist(),
        })

    raise ValueError("Content type {} is not supported.".format(content_type))