"""SageMaker inference handlers for a pickled XGBoost model.

Defines the ``input_fn`` / ``model_fn`` / ``output_fn`` hooks used by the
SageMaker XGBoost serving container.
"""

import json
import os
import pickle as pkl

import numpy as np
import xgboost as xgb
from sagemaker_containers.beta.framework import (
    content_types,
    encoders,
    env,
    modules,
    transformer,
    worker,
)

from sagemaker_xgboost_container import encoder as xgb_encoders


def input_fn(input_data, content_type):
    """Deserialize a request body into an ``xgb.DMatrix``.

    Args:
        input_data: Raw request payload.
        content_type: MIME type of ``input_data``.

    Returns:
        An ``xgb.DMatrix``. For JSON requests this DMatrix holds a single row
        built from ``instances[0]["features"]``. For any other content type,
        decoding is delegated to ``sagemaker_xgboost_container.encoder.decode``.
    """
    if content_type == content_types.JSON:
        obj = json.loads(input_data)
        # Only the first element of "instances" is scored; any further
        # instances in the request are ignored.
        features = obj["instances"][0]["features"]
        # Reshape the flat feature list into a single row (1 x n_features).
        array = np.array(features).reshape((1, -1))
        return xgb.DMatrix(array)
    return xgb_encoders.decode(input_data, content_type)


def model_fn(model_dir):
    """Load the model pickled at ``<model_dir>/model.bin``.

    Args:
        model_dir: Directory containing the ``model.bin`` artifact.

    Returns:
        The unpickled model object.
    """
    model_file = os.path.join(model_dir, "model.bin")
    # NOTE(review): pickle.load can execute arbitrary code; model.bin must be
    # a trusted artifact, never user-supplied data.
    # The context manager closes the file handle, which the previous
    # bare open() leaked.
    with open(model_file, "rb") as f:
        return pkl.load(f)


def output_fn(prediction, accept):
    """Serialize a prediction into a ``worker.Response``.

    Args:
        prediction: Model output. Only its first element is used as the score.
        accept: Requested response MIME type.

    Returns:
        ``worker.Response``. For ``application/json`` the body is
        ``{"predictions": [{"score": <float>, "predicted_label": 0|1}]}``,
        where the label is 1 when the score exceeds 0.5. For ``text/csv`` the
        body is ``prediction`` encoded via ``encoders.encode``.

    Raises:
        RuntimeError: If ``accept`` is not one of the supported types.
    """
    score = np.array(prediction)[0]

    if accept == "application/json":
        predicted_label = 1 if score > 0.5 else 0
        return_value = {
            # float() turns the numpy scalar into a plain JSON-serialisable float.
            "predictions": [{"score": float(score), "predicted_label": predicted_label}]
        }
        return worker.Response(json.dumps(return_value), mimetype=accept)

    if accept == "text/csv":
        # NOTE(review): the original computed a "yes"/"no" label here but never
        # used it; the raw prediction is what gets encoded. Confirm whether the
        # CSV response was meant to carry the label instead.
        return worker.Response(encoders.encode(prediction, accept), mimetype=accept)

    # The original raised the undefined name `RuntimeException` (a NameError);
    # RuntimeError is the built-in exception that was intended.
    raise RuntimeError("{} accept type is not supported.".format(accept))