# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import os
import pickle as pkl

import numpy as np
import sagemaker_xgboost_container.encoder as xgb_encoders


def model_fn(model_dir):
    """Deserialize and return the fitted model.

    Args:
        model_dir: Directory that contains the pickled ``xgboost-model`` file.

    Returns:
        The unpickled model object (used as a booster by ``predict_fn``).
    """
    model_file = "xgboost-model"
    # NOTE(security): unpickling can execute arbitrary code; only load model
    # artifacts from a trusted source.
    # The context manager guarantees the file handle is closed after loading.
    with open(os.path.join(model_dir, model_file), "rb") as f:
        booster = pkl.load(f)
    return booster


def input_fn(request_body, request_content_type):
    """Decode the request body into a DMatrix.

    The SageMaker XGBoost model server receives the request data body and the
    content type, and invokes ``input_fn``. The return value is passed on to
    ``predict_fn``.

    Args:
        request_body: Raw request payload.
        request_content_type: MIME type of the payload; only ``"text/libsvm"``
            is accepted.

    Returns:
        The DMatrix produced by ``xgb_encoders.libsvm_to_dmatrix``.

    Raises:
        ValueError: If the content type is not ``"text/libsvm"``.
    """
    if request_content_type == "text/libsvm":
        return xgb_encoders.libsvm_to_dmatrix(request_body)
    else:
        raise ValueError("Content type {} is not supported.".format(request_content_type))


def predict_fn(input_data, model):
    """Predict and attach per-feature contributions (SHAP values).

    The SageMaker XGBoost model server invokes ``predict_fn`` on the return
    value of ``input_fn``.

    Args:
        input_data: DMatrix returned by ``input_fn``.
        model: Booster returned by ``model_fn``.

    Returns:
        A two-dimensional NumPy array with one row per input record. The first
        column is the prediction and the remaining columns are that record's
        feature contributions as returned by ``predict(..., pred_contribs=True)``.
    """
    prediction = model.predict(input_data)
    # NOTE(review): only the contributions call disables feature-name
    # validation; confirm whether the plain predict call should match.
    feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)
    # Assumes a 1-D prediction (one score per row) and 2-D contributions --
    # TODO confirm for multi-class models, where the shapes differ.
    output = np.hstack((prediction[:, np.newaxis], feature_contribs))
    return output


def output_fn(predictions, content_type):
    """Serialize the prediction array into the response body.

    After invoking ``predict_fn``, the model server invokes ``output_fn``.

    Args:
        predictions: Two-dimensional array returned by ``predict_fn``.
        content_type: Requested response MIME type; only ``"text/csv"`` is
            accepted.

    Returns:
        A CSV string with one line per prediction row and values separated by
        commas. An empty batch yields an empty string.

    Raises:
        ValueError: If the content type is not ``"text/csv"``.
    """
    if content_type == "text/csv":
        # Serialize every row so a multi-record request is not truncated to
        # its first record; a single row produces exactly one CSV line.
        return "\n".join(",".join(str(x) for x in row) for row in predictions)
    else:
        raise ValueError("Content type {} is not supported.".format(content_type))