#!/usr/bin/env python
"""Minimal Flask inference server for a joblib-serialized estimator.

The server exposes two routes:

* ``GET /ping`` -- health check, always answers 200.
* ``POST /invocations`` -- scores a batch of samples (one sample per payload
  line) with ``estimator.predict_proba`` and returns one result line per sample.

Request and response payloads may be CSV or JSON Lines.
"""
import argparse
import csv
import logging
import os
from enum import Enum

import joblib
import json

from flask import Flask, Response, request


# This container supports the same types as Clarify job does.
class SupportedMimeTypes(Enum):
    """MIME types accepted for both request payloads and responses."""

    CSV = "text/csv"
    JSONLINES = "application/jsonlines"


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

supported_mime_types = [item.value for item in SupportedMimeTypes]

# load model
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="/opt/ml/model/")
args = parser.parse_args()
model_file = os.path.join(args.model_dir, "model.joblib")
estimator = joblib.load(model_file)

# HTTP Server
app = Flask(__name__)


def _bad_request(error_message: str, *, log_traceback: bool = False) -> Response:
    """Log ``error_message`` and return it as an HTTP 400 response.

    Args:
        error_message: Text that is both logged and sent back to the client.
        log_traceback: Set to True only when called from inside an ``except``
            block, so that the active exception's traceback is logged too.
            Outside a handler there is no exception to attach, and asking for
            one would append a meaningless ``NoneType: None`` to the log.
    """
    logger.error(error_message, exc_info=log_traceback)
    return Response(response=error_message, status=400)


@app.route("/ping", methods=["GET"])
def ping() -> Response:
    """Health check endpoint; always responds with HTTP 200."""
    return Response(response="\n", status=200)


@app.route("/invocations", methods=["POST"])
def predict() -> Response:
    """Score every sample in the request payload.

    The payload holds one sample per line, either as CSV or as JSON Lines
    (each line an object with a "features" list). The response holds one line
    per sample, formatted according to the negotiated ``Accept`` type.

    Returns:
        HTTP 200 with the predictions, or HTTP 400 with an explanatory message
        when the content type, accept type or payload is invalid, or when the
        estimator rejects the input with a ``ValueError``.
    """
    # input validation
    # `mimetype` is the content type without parameters, so a header such as
    # "text/csv; charset=utf-8" is still recognised as CSV.
    content_type = request.mimetype
    if content_type not in supported_mime_types:
        return _bad_request(
            f"Unsupported content type '{request.content_type}'. Please use one of {supported_mime_types}"
        )
    accept_type = request.accept_mimetypes.best_match(supported_mime_types)
    if accept_type not in supported_mime_types:
        return _bad_request(f"Unsupported accept type. Please use one of {supported_mime_types}")

    # Clarify job may send batch request to the container for better efficiency,
    # i.e. the payload can have multiple lines and each is a sample.
    try:
        lines = request.data.decode().splitlines()
    except UnicodeDecodeError as e:
        # Undecodable bytes are a client error, not a server fault.
        return _bad_request(f"Parsing request payload failed with error '{e}'", log_traceback=True)
    if not lines:
        return _bad_request("Payload is empty. Please post at least one sample.")

    # parse payload
    try:
        if content_type == SupportedMimeTypes.CSV.value:
            data = list(csv.reader(lines))
        else:  # SupportedMimeTypes.JSONLINES.value
            # "features" is a self-defined key in your Clarify job analysis config `predictor.content_template`.
            # It is used by the container to extract the list of features from a JSON line. The key is a contract
            # between the Clarify job and the container, here you can change it to something else, like "attributes",
            # but remember to update the `predictor.content_template` configuration accordingly.
            data = [json.loads(jsonline)["features"] for jsonline in lines]
    except (ValueError, KeyError, TypeError) as e:
        # ValueError: malformed JSON; KeyError: object without "features";
        # TypeError: the JSON line is not an object (e.g. a bare list or number).
        return _bad_request(f"Parsing request payload failed with error '{e}'", log_traceback=True)

    # do prediction
    try:
        # Here get the probability score
        predictions = estimator.predict_proba(data)[:, 1]
    except ValueError as e:
        return _bad_request(f"Prediction failed with error '{e}'", log_traceback=True)

    # format output
    if accept_type == SupportedMimeTypes.CSV.value:
        # output scores for CSV accept type
        predictions = [str(prediction) for prediction in predictions]
    else:  # SupportedMimeTypes.JSONLINES.value
        # "predicted_label" and "score" are self-defined keys in your Clarify job analysis config `predictor.label`
        # and `predictor.probability`. They are used by the Clarify job to extract predictions from container response.
        # The keys are contracts between the Clarify job and the container, here you can change them to something else,
        # but remember to update the analysis configuration accordingly.
        # `float(...)` turns the array scalar into a built-in float so that
        # `json.dumps` can serialize it whatever the estimator's output dtype is.
        predictions = [
            json.dumps({"predicted_label": 1 if prediction > 0.5 else 0, "score": float(prediction)})
            for prediction in predictions
        ]

    # For batch request, Clarify job expects the same number of result lines as the number of samples in the request.
    output = "\n".join(predictions)
    return Response(response=output, status=200, headers=[("Content-Type", accept_type)])


def main() -> None:
    """Start the Flask server on all interfaces, port 8080."""
    # start server
    app.run(host="0.0.0.0", port=8080)


if __name__ == "__main__":
    main()