# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from sagemakercv.detection import registry
from torch import nn


@registry.ROI_BOX_PREDICTOR.register("FastRCNNPredictor")
class FastRCNNPredictor(nn.Module):
    """Box predictor made of a 7x7 average pool and two linear heads.

    The pooled input is flattened per sample and fed to ``cls_score`` and
    ``bbox_pred`` independently.
    """

    def __init__(self, config, pretrained=None):
        """Build the pooling layer and both linear heads.

        Args:
            config: Configuration object. Reads
                ``MODEL.RESNETS.RES2_OUT_CHANNELS``,
                ``MODEL.ROI_BOX_HEAD.NUM_CLASSES`` and
                ``MODEL.CLS_AGNOSTIC_BBOX_REG``.
            pretrained: Accepted for interface compatibility; not used here.
        """
        super().__init__()

        # Input width is RES2_OUT_CHANNELS scaled by 2 ** (stage_index - 1),
        # with the stage index fixed at 4.
        stage_index = 4
        channel_multiplier = 2 ** (stage_index - 1)
        in_features = config.MODEL.RESNETS.RES2_OUT_CHANNELS * channel_multiplier
        num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES

        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7)
        self.cls_score = nn.Linear(in_features, num_classes)

        # Four regression outputs per group: two groups when the config flag
        # is set, otherwise one group per class.
        if config.MODEL.CLS_AGNOSTIC_BBOX_REG:
            box_groups = 2
        else:
            box_groups = num_classes
        self.bbox_pred = nn.Linear(in_features, box_groups * 4)

        # Weights ~ N(0, std); biases zero. The classifier is initialised
        # before the regressor.
        for layer, std in ((self.cls_score, 0.01), (self.bbox_pred, 0.001)):
            nn.init.normal_(layer.weight, mean=0, std=std)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Pool and flatten ``x``, then return ``(cls_logit, bbox_pred)``.

        NOTE(review): the flattened width must equal the linear layers'
        input size, which presumably means 7x7 spatial input - confirm
        against the caller.
        """
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.cls_score(flat), self.bbox_pred(flat)


@registry.ROI_BOX_PREDICTOR.register("FPNPredictor")
class FPNPredictor(nn.Module):
    """Two linear heads applied directly to the input features."""

    def __init__(self, cfg):
        """Build the classification and box-regression linear layers.

        Args:
            cfg: Configuration object. Reads
                ``MODEL.ROI_BOX_HEAD.NUM_CLASSES``,
                ``MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM`` and
                ``MODEL.CLS_AGNOSTIC_BBOX_REG``.
        """
        super().__init__()

        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        in_features = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM

        self.cls_score = nn.Linear(in_features, num_classes)

        # Two groups of four outputs when the config flag is set, otherwise
        # one group per class.
        if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
            box_groups = 2
        else:
            box_groups = num_classes
        self.bbox_pred = nn.Linear(in_features, box_groups * 4)

        # Weights ~ N(0, std); biases zero.
        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        nn.init.constant_(self.cls_score.bias, 0)
        nn.init.constant_(self.bbox_pred.bias, 0)

    def forward(self, x):
        """Return ``(scores, bbox_deltas)`` from the two linear heads."""
        return self.cls_score(x), self.bbox_pred(x)


@registry.ROI_BOX_PREDICTOR.register("CascadeFPNPredictor")
class CascadeFPNPredictor(nn.Module):
    """Two linear heads whose box-output width is chosen by an argument."""

    def __init__(self, cfg, class_agnostic=False):
        """Build the classification and box-regression linear layers.

        Args:
            cfg: Configuration object. Reads
                ``MODEL.ROI_BOX_HEAD.NUM_CLASSES`` and
                ``MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM``.
            class_agnostic: When truthy the regressor emits ``2 * 4``
                outputs; otherwise ``NUM_CLASSES * 4``.
        """
        super().__init__()

        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        in_features = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM

        self.cls_score = nn.Linear(in_features, num_classes)

        if class_agnostic:
            box_groups = 2
        else:
            box_groups = num_classes
        self.bbox_pred = nn.Linear(in_features, box_groups * 4)

        # Weights ~ N(0, std); biases zero.
        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        nn.init.constant_(self.cls_score.bias, 0)
        nn.init.constant_(self.bbox_pred.bias, 0)

    def forward(self, x):
        """Return ``(scores, bbox_deltas)`` from the two linear heads."""
        return self.cls_score(x), self.bbox_pred(x)


def make_roi_box_predictor(cfg):
    """Instantiate the predictor registered under the configured name.

    The name is read from ``cfg.MODEL.ROI_BOX_HEAD.PREDICTOR`` and looked up
    in ``registry.ROI_BOX_PREDICTOR``; the entry is called with ``cfg`` only.
    """
    predictor_name = cfg.MODEL.ROI_BOX_HEAD.PREDICTOR
    predictor_factory = registry.ROI_BOX_PREDICTOR[predictor_name]
    return predictor_factory(cfg)


def make_cascade_box_predictor(cfg):
    """Build one ``CascadeFPNPredictor`` per cascade stage.

    The stage count is ``cfg.MODEL.ROI_HEADS.CASCADE.STAGES``. When
    ``cfg.MODEL.CLS_AGNOSTIC_BBOX_REG`` is set, every stage is
    class-agnostic. Otherwise every stage except the last one is.

    Returns:
        An ``nn.ModuleList`` holding the predictors in stage order.
    """
    num_stages = cfg.MODEL.ROI_HEADS.CASCADE.STAGES
    stage_predictors = nn.ModuleList()
    for stage_idx in range(num_stages):
        if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
            agnostic = True
        else:
            agnostic = stage_idx != num_stages - 1
        stage_predictors.append(CascadeFPNPredictor(cfg, class_agnostic=agnostic))
    return stage_predictors