# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) 2018-2019 NVIDIA CORPORATION. All rights reserved.
"""
Miscellaneous utility functions
"""

import torch
from torch import nn
from torch.nn import functional as F

from sagemakercv.layers import Conv2d
from sagemakercv.layers.nhwc.misc import Conv2d_NHWC
from sagemakercv.core.poolers import Pooler
from sagemakercv.layers.nhwc import init


def get_group_gn(dim, dim_per_gp, num_groups):
    """get number of groups used by GroupNorm, based on number of channels.

    Exactly one of ``dim_per_gp`` (channels per group) and ``num_groups``
    is expected to be specified; the other one must be ``-1``.

    Args:
        dim: number of channels being normalized.
        dim_per_gp: channels per group, or ``-1`` when unused.
        num_groups: number of groups, or ``-1`` when unused.

    Returns:
        ``dim // dim_per_gp`` when ``dim_per_gp`` is positive, otherwise
        ``num_groups``.

    Raises:
        AssertionError: if neither argument is ``-1``, or if ``dim`` is not
            divisible by the selected value.
    """
    assert dim_per_gp == -1 or num_groups == -1, \
        "GroupNorm: can only specify G or C/G."

    if dim_per_gp > 0:
        assert dim % dim_per_gp == 0, \
            "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp)
        return dim // dim_per_gp

    assert dim % num_groups == 0, \
        "dim: {}, num_groups: {}".format(dim, num_groups)
    return num_groups


def group_norm(out_channels, cfg, affine=True, divisor=1):
    """Build a ``torch.nn.GroupNorm`` layer configured from ``cfg``.

    Args:
        out_channels: number of channels to normalize (before ``divisor``).
        cfg: config object exposing ``MODEL.GROUP_NORM.DIM_PER_GP``,
            ``MODEL.GROUP_NORM.NUM_GROUPS`` and ``MODEL.GROUP_NORM.EPSILON``.
        affine: forwarded to ``torch.nn.GroupNorm``.
        divisor: integer by which the channel count and both group settings
            are floor-divided before the layer is built.

    Returns:
        The constructed ``torch.nn.GroupNorm`` module.
    """
    out_channels = out_channels // divisor
    # Floor division keeps the "unused" marker intact: -1 // divisor == -1
    # for any positive divisor.
    dim_per_gp = cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor
    num_groups = cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor
    eps = cfg.MODEL.GROUP_NORM.EPSILON  # default: 1e-5
    return torch.nn.GroupNorm(
        get_group_gn(out_channels, dim_per_gp, num_groups),
        out_channels,
        eps,
        affine
    )


def _require_gn_cfg(use_gn, cfg):
    """Fail early with a clear message when GroupNorm lacks its config."""
    if use_gn and cfg is None:
        raise ValueError(
            "use_gn=True requires `cfg` (with MODEL.GROUP_NORM settings) "
            "to build the GroupNorm layer."
        )


def _assemble(conv, out_channels, use_gn, use_relu, cfg):
    """Return ``conv`` alone, or wrapped with optional GroupNorm / ReLU."""
    layers = [conv]
    if use_gn:
        layers.append(group_norm(out_channels, cfg))
    if use_relu:
        layers.append(nn.ReLU(inplace=True))
    if len(layers) > 1:
        return nn.Sequential(*layers)
    return conv


def make_conv3x3(
    in_channels,
    out_channels,
    dilation=1,
    stride=1,
    use_gn=False,
    use_relu=False,
    kaiming_init=True,
    nhwc=False,
    cfg=None
):
    """Create a 3x3 convolution, optionally followed by GroupNorm and ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        dilation: convolution dilation; also used as the padding.
        stride: convolution stride.
        use_gn: append a GroupNorm layer and drop the convolution bias.
        use_relu: append an in-place ReLU.
        kaiming_init: use ``kaiming_normal_`` (fan_out, relu) for the weight;
            otherwise a normal init with ``std=0.01``.
        nhwc: build ``Conv2d_NHWC`` instead of ``Conv2d`` and forward the
            flag to the init helpers.
        cfg: config object passed to ``group_norm``; required when
            ``use_gn`` is true, ignored otherwise.

    Returns:
        The bare convolution when neither ``use_gn`` nor ``use_relu`` is set,
        otherwise an ``nn.Sequential`` of the layers.

    Raises:
        ValueError: if ``use_gn`` is true and ``cfg`` is ``None``.
    """
    _require_gn_cfg(use_gn, cfg)

    conv_cls = Conv2d_NHWC if nhwc else Conv2d
    conv = conv_cls(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=not use_gn
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch.
    # NOTE: this function actually applies kaiming_normal_ when
    # kaiming_init is set.
    if kaiming_init:
        init.kaiming_normal_(
            conv.weight, mode="fan_out", nonlinearity="relu", nhwc=nhwc
        )
    else:
        init.normal_(conv.weight, std=0.01, nhwc=nhwc)
    if not use_gn:
        # The bias only exists when GroupNorm is not used.
        nn.init.constant_(conv.bias, 0)

    return _assemble(conv, out_channels, use_gn, use_relu, cfg)


def make_fc(dim_in, hidden_dim, use_gn=False, cfg=None):
    '''
        Caffe2 implementation uses XavierFill, which in fact
        corresponds to kaiming_uniform_ in PyTorch

        Args:
            dim_in: input feature size of the linear layer.
            hidden_dim: output feature size of the linear layer.
            use_gn: follow the linear layer with GroupNorm (no bias).
            cfg: config object passed to ``group_norm``; required when
                ``use_gn`` is true, ignored otherwise.

        Returns:
            ``nn.Sequential(fc, GroupNorm)`` when ``use_gn`` is true,
            otherwise the ``nn.Linear`` layer with a zero-initialized bias.

        Raises:
            ValueError: if ``use_gn`` is true and ``cfg`` is ``None``.
    '''
    _require_gn_cfg(use_gn, cfg)

    fc = nn.Linear(dim_in, hidden_dim, bias=not use_gn)
    nn.init.kaiming_uniform_(fc.weight, a=1)
    if use_gn:
        return nn.Sequential(fc, group_norm(hidden_dim, cfg))
    nn.init.constant_(fc.bias, 0)
    return fc


def conv_with_kaiming_uniform(use_gn=False, use_relu=False, cfg=None):
    """Return a factory that builds kaiming-uniform initialized convolutions.

    Args:
        use_gn: every built convolution is followed by GroupNorm and has
            no bias.
        use_relu: every built convolution is followed by an in-place ReLU.
        cfg: config object passed to ``group_norm``; required when
            ``use_gn`` is true, ignored otherwise.

    Returns:
        ``make_conv(in_channels, out_channels, kernel_size, stride=1,
        dilation=1, nhwc=False)``, which returns the bare convolution or an
        ``nn.Sequential`` when extra layers were requested.

    Raises:
        ValueError: if ``use_gn`` is true and ``cfg`` is ``None``.
    """
    _require_gn_cfg(use_gn, cfg)

    def make_conv(
        in_channels, out_channels, kernel_size, stride=1, dilation=1,
        nhwc=False
    ):
        conv_cls = Conv2d_NHWC if nhwc else Conv2d
        conv = conv_cls(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
            bias=not use_gn
        )
        # Caffe2 implementation uses XavierFill, which in fact
        # corresponds to kaiming_uniform_ in PyTorch
        init.kaiming_uniform_(conv.weight, a=1, nhwc=nhwc)
        # TODO: need to fix GN issue for NHWC
        if not use_gn:
            # The bias only exists when GroupNorm is not used.
            nn.init.constant_(conv.bias, 0)

        return _assemble(conv, out_channels, use_gn, use_relu, cfg)

    return make_conv