"""Mask R-CNN inference service exposed over HTTP with Flask.

The newest tensorpack checkpoint found in the model directory is loaded once
at import time. ``/ping`` reports health and ``/invocations`` runs prediction
on a base64-encoded image posted as JSON.
"""
import base64
import glob
import json
import os
import re
import sys
import tempfile
import traceback
from itertools import groupby
from threading import Lock

# The project packages imported below (config, dataset, model) are resolved
# from this directory.
sys.path.append("/mask-rcnn-tensorflow/MaskRCNN")

import cv2
import numpy as np
from flask import Flask, Response, request

from config import config as cfg
from config import finalize_configs
from dataset import DetectionDataset
from model.generalized_rcnn import ResNetFPNModel
from tensorpack.predict.base import OfflinePredictor
from tensorpack.predict.config import PredictConfig
from tensorpack.tfutils.sessinit import get_model_loader

# Matches checkpoint index files such as "model-12345.index" and captures the step.
_CHECKPOINT_STEP_RE = re.compile(r"model-(\d+)\.index$")


def _find_latest_checkpoint(model_dir):
    """Return the path of the ``model-*.index`` file with the highest step number.

    Plain string comparison would rank "model-9000" above "model-10000", so the
    numeric step is compared instead. Files whose step cannot be parsed rank
    lowest; ties are broken by path.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no ``model-*.index`` file.
    """
    candidates = glob.glob(os.path.join(model_dir, "model-*.index"))
    if not candidates:
        raise FileNotFoundError(f"No trained model (model-*.index) found in {model_dir}")

    def step_key(path):
        match = _CHECKPOINT_STEP_RE.search(os.path.basename(path))
        return (int(match.group(1)) if match else -1, path)

    return max(candidates, key=step_key)


class MaskRCNNService:
    """Holds a lazily created, shared tensorpack OfflinePredictor and helpers."""

    lock = Lock()  # guards one-time predictor creation
    predictor = None  # cached OfflinePredictor, set by get_predictor()

    @classmethod
    def get_predictor(cls):
        """Load the trained model once and return the cached OfflinePredictor.

        Environment variables read:
            SM_MODEL_DIR: directory holding ``model-*.index`` checkpoints
                (default ``/opt/ml/model``).
            RESNET_ARCH: ``resnet101`` selects ResNet-101 block counts; any other
                value (default ``resnet50``) selects ResNet-50.
            CONFIG__<A>__<B>...: overrides ``cfg.A.B`` with ``eval(<value>)``.

        Raises:
            FileNotFoundError: if no checkpoint is found in the model directory.
        """
        with cls.lock:
            # Another caller may already have loaded the model.
            if cls.predictor:
                return cls.predictor

            os.environ["TENSORPACK_FP16"] = "true"

            mask_rcnn_model = ResNetFPNModel(True)

            model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
            resnet_arch = os.environ.get("RESNET_ARCH", "resnet50")

            trained_model = _find_latest_checkpoint(model_dir)
            print(f"Using model: {trained_model}")

            # FPN + mask head; backbone depth depends on RESNET_ARCH.
            cfg.MODE_FPN = True
            cfg.MODE_MASK = True
            if resnet_arch == "resnet101":
                cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 23, 3]
            else:
                cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]

            # Apply CONFIG__A__B=<expr> environment overrides to cfg.A.B.
            cfg_prefix = "CONFIG__"
            for key, value in dict(os.environ).items():
                if not key.startswith(cfg_prefix):
                    continue
                attr_name = key[len(cfg_prefix) :].replace("__", ".")
                # SECURITY NOTE(review): eval() runs arbitrary code taken from the
                # environment. This is acceptable only while the container
                # environment is trusted. If overrides are always plain literals,
                # ast.literal_eval would be the safer choice.
                value = eval(value)
                print(f"update config: {attr_name}={value}")
                nested_var = cfg
                attr_list = attr_name.split(".")
                for attr in attr_list[:-1]:
                    nested_var = getattr(nested_var, attr)
                setattr(nested_var, attr_list[-1], value)

            # Calling DetectionDataset gets the number of COCO categories and
            # saves it in the configuration.
            DetectionDataset()
            finalize_configs(is_training=False)

            # PredictConfig binds the model, its input tensors and output tensors.
            proposal_prefix = "fpn" if cfg.MODE_FPN else "rpn"
            cls.predictor = OfflinePredictor(
                PredictConfig(
                    model=mask_rcnn_model,
                    session_init=get_model_loader(trained_model),
                    input_names=["images", "orig_image_dims"],
                    output_names=[
                        "generate_{}_proposals_topk_per_image/boxes".format(proposal_prefix),
                        "generate_{}_proposals_topk_per_image/scores".format(proposal_prefix),
                        "fastrcnn_all_scores",
                        "output/boxes",
                        "output/scores",
                        "output/labels",
                        "output/masks",
                    ],
                )
            )
            return cls.predictor

    @classmethod
    def predict(cls, img=None, img_id=None, rpn=False, score_threshold=0.8, mask_threshold=0.5):
        """Run the model on one image and return COCO-like annotations.

        Args:
            img: image array as returned by ``cv2.imread``; ``img.shape`` is fed
                to the model as the original image dimensions.
            img_id: identifier echoed back (as a string) in the result.
            rpn: when true, also return proposal boxes/scores and all class scores.
            score_threshold: detections whose score, rounded to one decimal, is
                below this are dropped.
            mask_threshold: per-pixel threshold applied to the soft masks.

        Returns:
            dict with ``img_id`` and ``annotations``, a list of dicts with keys
            ``bbox`` ([x, y, w, h]), ``category_id``, ``category_name`` and
            ``segmentation`` (RLE). With ``rpn`` set, it also contains
            ``rpn_boxes``, ``rpn_scores`` and ``all_scores``.
        """
        predictor = cls.predictor or cls.get_predictor()
        (
            rpn_boxes,
            rpn_scores,
            all_scores,
            final_boxes,
            final_scores,
            final_labels,
            masks,
        ) = predictor(np.expand_dims(img, axis=0), np.expand_dims(np.array(img.shape), axis=0))

        img_shape = (img.shape[0], img.shape[1])
        annotations = []
        for box, mask, score, category_id in zip(final_boxes, masks, final_scores, final_labels):
            # NOTE(review): the score is rounded to one decimal before comparing,
            # so e.g. 0.76 passes a 0.8 threshold. This is preserved as-is;
            # confirm it is intended.
            if not round(score, 1) >= score_threshold:
                continue
            x1, y1, x2, y2 = box.tolist()[:4]
            label = int(category_id)
            b_mask = cls.get_binary_mask(img_shape, box, mask, threshold=mask_threshold)
            annotations.append(
                {
                    "bbox": [int(x1), int(y1), int(x2 - x1), int(y2 - y1)],
                    "category_id": label,
                    "category_name": cfg.DATA.CLASS_NAMES[label],
                    "segmentation": cls.binary_mask_to_rle(b_mask),
                }
            )

        predictions = {"img_id": str(img_id), "annotations": annotations}
        if rpn:
            predictions["rpn_boxes"] = rpn_boxes.tolist()
            predictions["rpn_scores"] = rpn_scores.tolist()
            predictions["all_scores"] = all_scores.tolist()
        return predictions

    @classmethod
    def get_binary_mask(cls, img_shape, box, mask, threshold=0.5):
        """Paste a box-relative soft mask into a full-image binary mask.

        Args:
            img_shape: ``(height, width)`` of the returned mask.
            box: ``[x1, y1, x2, y2]`` in pixels; coordinates are truncated to int.
            mask: 2-D soft mask for the box; it is resized to the box size.
            threshold: resized pixels strictly greater than this become 1.

        Returns:
            ``uint8`` array of shape ``img_shape``, 1 inside the instance.
            Degenerate (zero/negative size) boxes yield an all-zero mask. Parts
            of the box outside the image are clipped away.
        """
        b_mask = np.zeros(shape=img_shape, dtype=np.uint8)
        x1, y1, x2, y2 = (int(v) for v in box.astype(int)[:4])
        width, height = x2 - x1, y2 - y1
        if width <= 0 or height <= 0:
            # cv2.resize rejects an empty target size; such a box covers no pixels.
            return b_mask

        # cv2.resize takes dsize as (width, height).
        a_mask = (cv2.resize(mask, (width, height)) > threshold).astype(np.uint8)

        # Paste only the intersection of the box with the image. This avoids a
        # shape mismatch (or negative-index wrap-around) for boxes on the border.
        img_h, img_w = img_shape[0], img_shape[1]
        dst_x1, dst_y1 = max(x1, 0), max(y1, 0)
        dst_x2, dst_y2 = min(x2, img_w), min(y2, img_h)
        if dst_x1 < dst_x2 and dst_y1 < dst_y2:
            b_mask[dst_y1:dst_y2, dst_x1:dst_x2] = a_mask[
                dst_y1 - y1 : dst_y2 - y1, dst_x1 - x1 : dst_x2 - x1
            ]
        return b_mask

    @classmethod
    def binary_mask_to_rle(cls, binary_mask):
        """Encode a binary mask as uncompressed run-length encoding.

        Runs are counted over the mask flattened in row-major (C) order and
        alternate 0-runs / 1-runs, always starting with a 0-run. A leading 0
        count is inserted when the first pixel is 1.

        NOTE(review): COCO's uncompressed RLE is column-major (Fortran order).
        Clients must decode with the same row-major order used here; confirm
        before changing.

        Returns:
            ``{"counts": [int, ...], "size": [height, width]}``.
        """
        rle = {"counts": [], "size": list(binary_mask.shape)}
        flat = binary_mask.ravel(order="C")
        if flat.size == 0:
            return rle
        # Indices where the value changes mark the start of each new run.
        change_points = np.flatnonzero(flat[1:] != flat[:-1]) + 1
        boundaries = np.concatenate(([0], change_points, [flat.size]))
        runs = np.diff(boundaries).tolist()
        if flat[0] == 1:
            rle["counts"].append(0)
        rle["counts"].extend(runs)
        return rle


# Create the predictor at import time so the first request does not pay for it.
MaskRCNNService.get_predictor()

app = Flask(__name__)


def _json_response(payload, status):
    """Serialize ``payload`` as JSON and wrap it in a Flask Response."""
    return Response(response=json.dumps(payload), status=status, mimetype="application/json")


def _decode_image(img_data):
    """Decode a base64 image string with ``cv2.imread``; returns None if undecodable."""
    with tempfile.NamedTemporaryFile() as fh:
        fh.write(base64.decodebytes(img_data.encode("utf-8")))
        # Flush so cv2 sees the full contents when it reopens the file by name.
        fh.flush()
        return cv2.imread(fh.name, cv2.IMREAD_COLOR)


@app.route("/ping", methods=["GET"])
def health_check():
    """Determine if the container is working and healthy.

    In this sample container, we declare it healthy if we can load the model
    successfully and create a predictor.
    """
    health = MaskRCNNService.get_predictor() is not None  # You can insert a health check here
    status = 200 if health else 404
    return Response(response="\n", status=status, mimetype="application/json")


@app.route("/invocations", methods=["POST"])
def inference():
    """Run prediction on a JSON request.

    Expected JSON body: ``img_id`` and ``img_data`` (base64-encoded image) are
    required. ``rpn`` (default False), ``score_threshold`` (default 0.8) and
    ``mask_threshold`` (default 0.5) are optional.

    Returns 200 with the prediction JSON. Returns 415 for a non-JSON request,
    400 for missing fields or an undecodable image, and 500 for any other failure.
    """
    if not request.is_json:
        result = {"error": "Content type is not application/json"}
        print(result)
        return _json_response(result, 415)

    try:
        content = request.get_json()

        missing = [field for field in ("img_id", "img_data") if field not in content]
        if missing:
            result = {"error": f"Missing required field(s): {', '.join(missing)}"}
            print(result)
            return _json_response(result, 400)

        img = _decode_image(content["img_data"])
        if img is None:
            result = {"error": "img_data could not be decoded as an image"}
            print(result)
            return _json_response(result, 400)

        pred = MaskRCNNService.predict(
            img=img,
            img_id=content["img_id"],
            rpn=content.get("rpn", False),
            score_threshold=content.get("score_threshold", 0.8),
            mask_threshold=content.get("mask_threshold", 0.5),
        )
        return _json_response(pred, 200)
    except Exception:
        # Top-level request boundary: log the full traceback, hide details from the client.
        traceback.print_exc()
        return _json_response({"error": "Internal server error"}, 500)