# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from sagemakercv.layers import FrozenBatchNorm2d


class FrozenBatchNorm2d_NHWC(torch.jit.ScriptModule):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    The per-channel scale and shift are reshaped to ``(1, 1, 1, C)`` before
    being applied, so they broadcast along the *last* axis of the input
    (presumably an ``(N, H, W, C)`` channels-last tensor -- confirm with
    callers).

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters, so they are part of the state dict but
    are not returned by ``parameters()``.

    Args:
        n: Number of channels; length of every registered buffer.
        eps: Non-negative value added to ``running_var`` before the reciprocal
            square root. The default of ``0.0`` keeps the output identical to
            the historical behaviour; pass a small positive value (for
            example ``1e-5``) to keep the scale finite when a channel has
            zero running variance.

    Raises:
        ValueError: If ``eps`` is negative.
    """

    # Expose ``eps`` to TorchScript as a compile-time constant so that
    # ``forward`` can read it.
    __constants__ = ["eps"]

    def __init__(self, n: int, eps: float = 0.0):
        super(FrozenBatchNorm2d_NHWC, self).__init__()
        if eps < 0:
            raise ValueError("eps must be non-negative, got {}".format(eps))
        self.eps = float(eps)
        # Identity affine transform and unit statistics by default; real
        # values are expected to be loaded from a checkpoint.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    @torch.jit.script_method
    def forward(self, x):
        # Fold normalisation and the affine transform into a single
        # multiply-add: y = (x - mean) * weight / sqrt(var + eps) + bias.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        # Shape (1, 1, 1, C) so both terms broadcast over the last axis of x.
        scale = scale.reshape(1, 1, 1, -1)
        bias = bias.reshape(1, 1, 1, -1)
        return x * scale + bias