# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) 2018-2019 NVIDIA CORPORATION. All rights reserved.
"""
Variant of the resnet module that takes cfg as an argument.

The building blocks are selected by name. The names are strings read from the
config and resolved through the registries defined at the bottom of this file:

    cfg.MODEL.RESNETS.STEM_FUNC    -> _STEM_MODULES            (e.g. "StemWithFixedBatchNorm", "StemWithGN")
    cfg.MODEL.RESNETS.TRANS_FUNC   -> _TRANSFORMATION_MODULES  (e.g. "BottleneckWithFixedBatchNorm", "BottleneckWithGN")
    cfg.MODEL.BACKBONE.CONV_BODY   -> _STAGE_SPECS             (e.g. "R-50-C4", "R-50-FPN")

Example usage:

    model = ResNet(cfg)
"""
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn

from sagemakercv.layers.nhwc import MaxPool2d_NHWC, FrozenBatchNorm2d_NHWC
from sagemakercv.layers.nhwc.misc import Conv2d_NHWC
from sagemakercv.layers.nhwc import kaiming_uniform_
from sagemakercv.layers.nhwc import nchw_to_nhwc_transform, nhwc_to_nchw_transform
from sagemakercv.layers import FrozenBatchNorm2d
from sagemakercv.layers import Conv2d
from sagemakercv.utils.registry import Registry
from sagemakercv.layers.make_layers import group_norm


# ResNet stage specification
StageSpec = namedtuple(
    "StageSpec",
    [
        "index",  # Index of the stage, e.g. 1, 2, ..., 5
        "block_count",  # Number of residual blocks in the stage
        "return_features",  # True => return the last feature map from this stage
    ],
)


def _make_stage_specs(layout):
    """Turn ``(index, block_count, return_features)`` triples into a tuple of ``StageSpec``."""
    return tuple(
        StageSpec(index=i, block_count=c, return_features=r) for (i, c, r) in layout
    )


# -----------------------------------------------------------------------------
# Standard ResNet models
# -----------------------------------------------------------------------------
# ResNet-50 (including all stages)
ResNet50StagesTo5 = _make_stage_specs(
    ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True))
)
# ResNet-50 up to stage 4 (excludes stage 5)
ResNet50StagesTo4 = _make_stage_specs(
    ((1, 3, False), (2, 4, False), (3, 6, True))
)
# ResNet-101 (including all stages)
ResNet101StagesTo5 = _make_stage_specs(
    ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True))
)
# ResNet-101 up to stage 4 (excludes stage 5)
ResNet101StagesTo4 = _make_stage_specs(
    ((1, 3, False), (2, 4, False), (3, 23, True))
)
# ResNet-50-FPN (including all stages)
ResNet50FPNStagesTo5 = _make_stage_specs(
    ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True))
)
# ResNet-101-FPN (including all stages)
ResNet101FPNStagesTo5 = _make_stage_specs(
    ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True))
)
# ResNet-152-FPN (including all stages)
ResNet152FPNStagesTo5 = _make_stage_specs(
    ((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True))
)


class ResNet(nn.Module):
    """ResNet backbone assembled from a config object.

    The stem is stored as ``self.stem`` and each stage as ``self.layer<index>``.
    ``forward`` returns a list holding the output of every stage whose
    ``StageSpec.return_features`` is true.
    """

    def __init__(self, cfg):
        """Build the stem and stages described by ``cfg``.

        Reads ``cfg.OPT_LEVEL``, ``cfg.MODEL.BACKBONE.{CONV_BODY,
        FREEZE_CONV_BODY_AT}`` and ``cfg.MODEL.RESNETS.{STEM_FUNC, TRANS_FUNC,
        NUM_GROUPS, WIDTH_PER_GROUP, STEM_OUT_CHANNELS, RES2_OUT_CHANNELS,
        STRIDE_IN_1X1}``.
        """
        super(ResNet, self).__init__()

        # If we want to use the cfg in forward(), then we should make a copy
        # of it and store it for later use:
        # self.cfg = cfg.clone()

        # Translate string names to implementations
        stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC]
        stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY]
        transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC]

        # Construct the stem module
        self.stem = stem_module(cfg)

        # Construct the specified ResNet stages
        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
        in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        stage2_bottleneck_channels = num_groups * width_per_group
        stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
        self.stages = []
        self.return_features = {}
        # Get the tensor layout (NHWC vs NCHW)
        self.nhwc = cfg.OPT_LEVEL == "O4"
        for stage_spec in stage_specs:
            name = "layer" + str(stage_spec.index)
            # Channel widths double with every stage after the first one.
            stage2_relative_factor = 2 ** (stage_spec.index - 1)
            bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
            out_channels = stage2_out_channels * stage2_relative_factor
            module = _make_stage(
                transformation_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage_spec.block_count,
                num_groups,
                cfg.MODEL.RESNETS.STRIDE_IN_1X1,
                # Stride 1 for the first stage, stride 2 for every later stage.
                first_stride=int(stage_spec.index > 1) + 1,
                nhwc=self.nhwc,
            )
            in_channels = out_channels
            self.add_module(name, module)
            self.stages.append(name)
            self.return_features[name] = stage_spec.return_features
        self.has_fpn = "FPN" in cfg.MODEL.BACKBONE.CONV_BODY

        # Optionally freeze (requires_grad=False) parts of the backbone
        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)

    def _freeze_backbone(self, freeze_at):
        """Set ``requires_grad=False`` on the first ``freeze_at`` parts of the net.

        Part 0 is the stem and part ``k`` (k >= 1) is ``self.layer<k>``. A
        negative ``freeze_at`` freezes nothing.
        """
        if freeze_at < 0:
            return
        for stage_index in range(freeze_at):
            if stage_index == 0:
                module = self.stem  # stage 0 is the stem
            else:
                module = getattr(self, "layer" + str(stage_index))
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Run the stem and all stages; return the list of requested feature maps."""
        # NOTE(review): no NCHW -> NHWC conversion is applied to ``x`` here, so
        # when ``self.nhwc`` is set the input is presumably already laid out
        # for the NHWC layers -- confirm against the callers.
        x = self.stem(x)
        outputs = []
        for stage_name in self.stages:
            x = getattr(self, stage_name)(x)
            if self.return_features[stage_name]:
                outputs.append(x)
        # Without an FPN on top, the NHWC feature maps are converted back to NCHW.
        if self.nhwc and not self.has_fpn:
            outputs = [nhwc_to_nchw_transform(t) for t in outputs]
        return outputs


class ResNetHead(nn.Module):
    """Sequence of ResNet stages used as a head; applies them one after another."""

    def __init__(
        self,
        block_module,
        stages,
        num_groups=1,
        width_per_group=64,
        stride_in_1x1=True,
        stride_init=None,
        res2_out_channels=256,
        dilation=1,
        nhwc=False
    ):
        """Build the stages listed in ``stages``.

        Args:
            block_module: key into ``_TRANSFORMATION_MODULES``.
            stages: sequence of ``StageSpec``; channel widths are derived from
                the index of the first entry.
            num_groups, width_per_group: their product is the stage-2
                bottleneck width.
            stride_in_1x1: forwarded to every block.
            stride_init: stride of the first block of the first stage; when
                falsy, ``1`` for stage index 1 and ``2`` otherwise.
            res2_out_channels: stage-2 output width.
            dilation: forwarded to every block.
            nhwc: forwarded to every block.
        """
        super(ResNetHead, self).__init__()

        stage2_relative_factor = 2 ** (stages[0].index - 1)
        stage2_bottleneck_channels = num_groups * width_per_group
        out_channels = res2_out_channels * stage2_relative_factor
        in_channels = out_channels // 2
        bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor

        block_module = _TRANSFORMATION_MODULES[block_module]

        self.stages = []
        stride = stride_init
        for stage in stages:
            name = "layer" + str(stage.index)
            if not stride:
                stride = int(stage.index > 1) + 1
            module = _make_stage(
                block_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage.block_count,
                num_groups,
                stride_in_1x1,
                first_stride=stride,
                dilation=dilation,
                nhwc=nhwc,
            )
            # ``stride_init`` only applies to the first stage.
            stride = None
            self.add_module(name, module)
            self.stages.append(name)

    def forward(self, x):
        """Apply every stage in construction order."""
        for stage in self.stages:
            x = getattr(self, stage)(x)
        return x


def _make_stage(
    transformation_module,
    in_channels,
    bottleneck_channels,
    out_channels,
    block_count,
    num_groups,
    stride_in_1x1,
    first_stride,
    dilation=1,
    nhwc=False,
):
    """Stack ``block_count`` blocks of ``transformation_module`` into an ``nn.Sequential``.

    Only the first block receives ``first_stride`` and ``in_channels``; the
    remaining blocks use stride 1 and take ``out_channels`` as their input
    width.
    """
    blocks = []
    stride = first_stride
    for _ in range(block_count):
        blocks.append(
            transformation_module(
                in_channels,
                bottleneck_channels,
                out_channels,
                num_groups,
                stride_in_1x1,
                stride,
                dilation=dilation,
                nhwc=nhwc,
            )
        )
        stride = 1
        in_channels = out_channels
    return nn.Sequential(*blocks)


class Bottleneck(torch.nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block with a pluggable norm layer.

    A 1x1 conv + norm projection (``self.downsample``) is applied to the
    shortcut whenever ``in_channels != out_channels``; otherwise the shortcut
    is the identity and ``self.downsample`` is ``None``.
    """

    __constants__ = ['downsample']

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups,
        stride_in_1x1,
        stride,
        dilation,
        norm_func,
        nhwc=False
    ):
        """Create the three convolutions, their norm layers and the shortcut.

        Args:
            in_channels, bottleneck_channels, out_channels: channel widths.
            num_groups: ``groups`` of the 3x3 convolution.
            stride_in_1x1: put the stride on the first 1x1 conv when true,
                on the 3x3 conv otherwise.
            stride: stride of the block; forced to 1 when ``dilation > 1``.
            dilation: dilation (and padding) of the 3x3 convolution.
            norm_func: callable taking a channel count, returning a norm layer.
            nhwc: use ``Conv2d_NHWC`` instead of ``Conv2d``.
        """
        super(Bottleneck, self).__init__()
        conv = Conv2d_NHWC if nhwc else Conv2d
        if in_channels != out_channels:
            down_stride = stride if dilation == 1 else 1
            self.downsample = nn.Sequential(
                conv(
                    in_channels, out_channels,
                    kernel_size=1, stride=down_stride, bias=False
                ),
                norm_func(out_channels),
            )
            for layer in self.downsample.modules():
                if isinstance(layer, conv):
                    kaiming_uniform_(layer.weight, a=1, nhwc=nhwc)
        else:
            self.downsample = None

        if dilation > 1:
            stride = 1  # reset to be 1

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = conv(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
        )
        self.bn1 = norm_func(bottleneck_channels)
        # TODO: specify init for the above

        self.conv2 = conv(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation
        )
        self.bn2 = norm_func(bottleneck_channels)

        self.conv3 = conv(
            bottleneck_channels, out_channels, kernel_size=1, bias=False
        )
        self.bn3 = norm_func(out_channels)

        for layer in (self.conv1, self.conv2, self.conv3):
            kaiming_uniform_(layer.weight, a=1, nhwc=nhwc)

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = out.relu()

        return out


class BottleneckWithFixedBatchNorm(Bottleneck):
    """``Bottleneck`` using frozen batch norm (the NHWC variant when ``nhwc`` is set)."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        nhwc=False
    ):
        frozen_batch_norm = FrozenBatchNorm2d_NHWC if nhwc else FrozenBatchNorm2d
        super(BottleneckWithFixedBatchNorm, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=frozen_batch_norm,
            nhwc=nhwc
        )


class BottleneckWithGN(Bottleneck):
    """``Bottleneck`` using ``group_norm`` as the norm layer."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        nhwc=False
    ):
        # NOTE(review): ``nhwc`` is accepted (``_make_stage`` always passes it)
        # but is not forwarded, so this block always uses the plain ``Conv2d``
        # path. Presumably ``group_norm`` has no NHWC variant -- confirm
        # whether GN together with the NHWC layout is meant to be supported.
        super(BottleneckWithGN, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=group_norm
        )


class BaseStem(torch.nn.Module):
    """Network stem: 7x7 stride-2 conv, norm, ReLU, then 3x3 stride-2 max pool."""

    def __init__(self, cfg, norm_func):
        """Build the stem.

        Args:
            cfg: config; reads ``cfg.MODEL.RESNETS.STEM_OUT_CHANNELS`` and
                ``cfg.OPT_LEVEL`` (``"O4"`` selects the NHWC layers).
            norm_func: callable taking a channel count, returning a norm layer.
        """
        super(BaseStem, self).__init__()

        out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        self.nhwc = cfg.OPT_LEVEL == "O4"
        conv = Conv2d_NHWC if self.nhwc else Conv2d
        self.conv1 = conv(
            3, out_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_func(out_channels)
        max_pool = MaxPool2d_NHWC if self.nhwc else nn.MaxPool2d
        self.max_pool = max_pool(kernel_size=3, stride=2, padding=1)

        # Pass the layout flag, as the bottleneck blocks do for their convs,
        # so the initializer knows which layer type produced the weight.
        kaiming_uniform_(self.conv1.weight, a=1, nhwc=self.nhwc)

    def forward(self, x):
        """Apply conv -> norm -> ReLU -> max pool."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.max_pool(x)
        return x


class StemWithFixedBatchNorm(BaseStem):
    """``BaseStem`` using frozen batch norm (NHWC variant when ``cfg.OPT_LEVEL == "O4"``)."""

    def __init__(self, cfg):
        norm_func = FrozenBatchNorm2d_NHWC if cfg.OPT_LEVEL == "O4" else FrozenBatchNorm2d
        super(StemWithFixedBatchNorm, self).__init__(
            cfg, norm_func
        )


class StemWithGN(BaseStem):
    """``BaseStem`` using ``group_norm`` as the norm layer."""

    def __init__(self, cfg):
        super(StemWithGN, self).__init__(cfg, norm_func=group_norm)


# Name -> bottleneck block implementation (cfg.MODEL.RESNETS.TRANS_FUNC).
_TRANSFORMATION_MODULES = Registry({
    "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm,
    "BottleneckWithGN": BottleneckWithGN,
})

# Name -> stem implementation (cfg.MODEL.RESNETS.STEM_FUNC).
_STEM_MODULES = Registry({
    "StemWithFixedBatchNorm": StemWithFixedBatchNorm,
    "StemWithGN": StemWithGN,
})

# Name -> stage layout (cfg.MODEL.BACKBONE.CONV_BODY).
_STAGE_SPECS = Registry({
    "R-50-C4": ResNet50StagesTo4,
    "R-50-C5": ResNet50StagesTo5,
    "R-101-C4": ResNet101StagesTo4,
    "R-101-C5": ResNet101StagesTo5,
    "R-50-FPN": ResNet50FPNStagesTo5,
    "R-50-FPN-RETINANET": ResNet50FPNStagesTo5,
    "R-101-FPN": ResNet101FPNStagesTo5,
    "R-101-FPN-RETINANET": ResNet101FPNStagesTo5,
    "R-152-FPN": ResNet152FPNStagesTo5,
})