import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (no bias; BN follows in all users)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two-conv residual block used by ResNet-18/34.

    ``downsample``, when given, projects the identity branch so that it
    matches the main branch's channels/stride before the addition.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck used by ResNet-50+.

    The final 1x1 conv expands channels by ``expansion`` (4).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class ConvBottleneck(nn.Module):
    """U-Net skip-connection merge: concat(decoder, encoder) -> 3x3 conv -> ReLU.

    ``in_channels`` must equal the sum of the decoder and encoder channel
    counts being concatenated.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)


class UnetDecoderBlock(nn.Module):
    """2x nearest-neighbor upsample followed by 3x3 conv + ReLU.

    NOTE(review): ``middle_channels`` is accepted but unused — the block goes
    straight from ``in_channels`` to ``out_channels``. Kept for signature
    compatibility with callers that pass three channel counts.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ResNet(nn.Module):
    """ResNet backbone (feature extractor only — no pooling/FC head).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        in_channels: channels of the input image (4 for SpaceNet imagery).
    """

    def __init__(self, block, layers, in_channels=3):
        # ``inplanes`` must exist before _make_layer is called.
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # He initialization (fan-out variant) for convs; identity-like BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; a 1x1-conv downsample is added when the identity
        branch cannot be added to the main branch as-is."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


def resnet34(**kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def densenet121(pretrained=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks"
    <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
            NOTE(review): accepted for API compatibility but not acted on here
            — weight loading happens elsewhere, if at all.
    """
    model = DenseNet(num_init_features=64, growth_rate=32,
                     block_config=(6, 12, 24, 16), **kwargs)
    return model


def densenet161(pretrained=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks"
    <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
            NOTE(review): accepted for API compatibility but not acted on here.
    """
    model = DenseNet(num_init_features=96, growth_rate=48,
                     block_config=(6, 12, 36, 24), **kwargs)
    return model


# Per-encoder wiring for EncoderDecoder:
#   filters         - channel counts of the five encoder stages
#   decoder_filters - output channels of the four decoder stages
#   last_upsample   - channels of the final stride-recovery upsample block
#   init_op         - zero-arg callable constructing the encoder
encoder_params = {
    'resnet34': {
        'filters': [64, 64, 128, 256, 512],
        'decoder_filters': [64, 128, 256, 512],
        'last_upsample': 64,
        'init_op': partial(resnet34, in_channels=4)
    },
    'densenet161': {
        'filters': [96, 384, 768, 2112, 2208],
        'decoder_filters': [64, 128, 256, 256],
        'last_upsample': 64,
        'url': None,
        'init_op': densenet161
    },
    'densenet121': {
        'filters': [64, 256, 512, 1024, 1024],
        'decoder_filters': [64, 128, 256, 256],
        'last_upsample': 64,
        'url': None,
        'init_op': densenet121
    }
}


class AbstractModel(nn.Module):
    """Base class providing weight init and parameter-group helpers."""

    def _initialize_weights(self):
        # Kaiming init for conv-like layers; identity-like BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                m.weight.data = nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    @property
    def first_layer_params_name(self):
        return 'conv1'


class EncoderDecoder(AbstractModel):
    """Generic U-Net built from a configurable encoder (see encoder_params).

    Subclasses set class-level knobs *before* calling ``super().__init__``:
    ``first_layer_stride_two`` (encoder's first stage halves resolution),
    ``decoder_block`` and ``bottleneck_type``. They must also implement
    :meth:`get_encoder` to slice the encoder into five stages.
    """

    def __init__(self, num_classes, num_channels=3, encoder_name='resnet34'):
        # Defaults for knobs a subclass did not set in its own __init__.
        if not hasattr(self, 'first_layer_stride_two'):
            self.first_layer_stride_two = False
        if not hasattr(self, 'decoder_block'):
            self.decoder_block = UnetDecoderBlock
        if not hasattr(self, 'bottleneck_type'):
            self.bottleneck_type = ConvBottleneck
        self.filters = encoder_params[encoder_name]['filters']
        self.decoder_filters = encoder_params[encoder_name].get(
            'decoder_filters', self.filters[:-1])
        self.last_upsample_filters = encoder_params[encoder_name].get(
            'last_upsample', self.decoder_filters[0] // 2)

        super().__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes

        # One bottleneck per decoder stage, deepest first: it merges the
        # upsampled decoder features with the matching encoder skip.
        self.bottlenecks = nn.ModuleList([
            self.bottleneck_type(self.filters[-i - 2] + f, f)
            for i, f in enumerate(reversed(self.decoder_filters[:]))])

        self.decoder_stages = nn.ModuleList([
            self.get_decoder(idx)
            for idx in range(0, len(self.decoder_filters))])

        if self.first_layer_stride_two:
            # Recover the factor-of-two lost by the encoder's stem.
            self.last_upsample = self.decoder_block(
                self.decoder_filters[0],
                self.last_upsample_filters,
                self.last_upsample_filters)

        self.final = self.make_final_classifier(
            self.last_upsample_filters if self.first_layer_stride_two
            else self.decoder_filters[0],
            num_classes)

        self._initialize_weights()

        # Encoder is built last so _initialize_weights does not clobber any
        # pretrained weights loaded by init_op.
        encoder = encoder_params[encoder_name]['init_op']()
        self.encoder_stages = nn.ModuleList(
            [self.get_encoder(encoder, idx)
             for idx in range(len(self.filters))])

    # noinspection PyCallingNonCallable
    def forward(self, x):
        # Encoder pass, keeping each stage's output for the skip connections.
        enc_results = []
        for stage in self.encoder_stages:
            x = stage(x)
            enc_results.append(torch.cat(x, dim=1)
                               if isinstance(x, tuple) else x.clone())

        last_dec_out = enc_results[-1]
        x = last_dec_out
        # Decoder pass, deepest stage first; each step upsamples then merges
        # with the next-shallower encoder output.
        for idx, bottleneck in enumerate(self.bottlenecks):
            rev_idx = - (idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, enc_results[rev_idx - 1])

        if self.first_layer_stride_two:
            x = self.last_upsample(x)

        f = self.final(x)
        return f

    def get_decoder(self, layer):
        """Decoder stage ``layer``; the deepest stage reads raw encoder
        channels, the others read the previous decoder stage's output."""
        in_channels = self.filters[layer + 1] if layer + 1 == len(
            self.decoder_filters) else self.decoder_filters[layer + 1]
        return self.decoder_block(in_channels,
                                  self.decoder_filters[layer],
                                  self.decoder_filters[max(layer, 0)])

    def make_final_classifier(self, in_filters, num_classes):
        """Per-pixel classifier head: a single 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(in_filters, num_classes, 1, padding=0)
        )

    def get_encoder(self, encoder, layer):
        """Return stage ``layer`` (0..4) of ``encoder`` — subclass hook."""
        raise NotImplementedError

    @property
    def first_layer_params(self):
        return _get_layers_params([self.encoder_stages[0]])

    @property
    def layers_except_first_params(self):
        layers = get_slice(self.encoder_stages, 1, -1) + [
            self.bottlenecks, self.decoder_stages, self.final]
        return _get_layers_params(layers)


class _DenseLayer(nn.Sequential):
    """BN-ReLU-1x1conv-BN-ReLU-3x3conv; output is concatenated onto input."""

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features,
                                           bn_size * growth_rate,
                                           kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        # Dense connectivity: every layer sees all previous feature maps.
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    """Stack of dense layers; channel count grows by growth_rate per layer."""

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                 drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate,
                                growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    """Between dense blocks: BN-ReLU-1x1conv (channel squeeze) + 2x avg-pool."""

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features,
                                          num_output_features,
                                          kernel_size=1, stride=1,
                                          bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks"
    <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first
            convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0,
                 num_classes=1000):
        super(DenseNet, self).__init__()

        # First convolution. NOTE(review): input channels are hard-coded to 4
        # — presumably 4-band SpaceNet imagery; confirm against the caller.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(4, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, with a halving transition between consecutive ones.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo. Fixed: nn.init.kaiming_normal was
        # deprecated and later removed; use the in-place underscore variant.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # kernel_size=7 assumes ImageNet-sized input (final 7x7 feature map).
        out = F.avg_pool2d(out, kernel_size=7,
                           stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out


class SelimSef_SpaceNet4_ResNet34UNet(EncoderDecoder):
    """U-Net over a 4-channel ResNet-34 encoder, 3 output classes."""

    def __init__(self):
        self.first_layer_stride_two = True
        super().__init__(3, 4, 'resnet34')

    def get_encoder(self, encoder, layer):
        if layer == 0:
            return nn.Sequential(
                encoder.conv1,
                encoder.bn1,
                encoder.relu)
        elif layer == 1:
            return nn.Sequential(
                encoder.maxpool,
                encoder.layer1)
        elif layer == 2:
            return encoder.layer2
        elif layer == 3:
            return encoder.layer3
        elif layer == 4:
            return encoder.layer4


class SelimSef_SpaceNet4_DenseNet121Unet(EncoderDecoder):
    """U-Net over a DenseNet-121 encoder, 3 output classes."""

    def __init__(self):
        self.first_layer_stride_two = True
        super().__init__(3, 3, 'densenet121')

    def get_encoder(self, encoder, layer):
        if layer == 0:
            return nn.Sequential(
                encoder.features.conv0,  # conv
                encoder.features.norm0,  # bn
                encoder.features.relu0   # relu
            )
        elif layer == 1:
            return nn.Sequential(encoder.features.pool0,
                                 encoder.features.denseblock1)
        elif layer == 2:
            return nn.Sequential(encoder.features.transition1,
                                 encoder.features.denseblock2)
        elif layer == 3:
            return nn.Sequential(encoder.features.transition2,
                                 encoder.features.denseblock3)
        elif layer == 4:
            return nn.Sequential(encoder.features.transition3,
                                 encoder.features.denseblock4,
                                 encoder.features.norm5,
                                 nn.ReLU())


class SelimSef_SpaceNet4_DenseNet161Unet(EncoderDecoder):
    """U-Net over a DenseNet-161 encoder, 3 output classes."""

    def __init__(self):
        self.first_layer_stride_two = True
        super().__init__(3, 3, 'densenet161')

    def get_encoder(self, encoder, layer):
        if layer == 0:
            return nn.Sequential(
                encoder.features.conv0,  # conv
                encoder.features.norm0,  # bn
                encoder.features.relu0   # relu
            )
        elif layer == 1:
            return nn.Sequential(encoder.features.pool0,
                                 encoder.features.denseblock1)
        elif layer == 2:
            return nn.Sequential(encoder.features.transition1,
                                 encoder.features.denseblock2)
        elif layer == 3:
            return nn.Sequential(encoder.features.transition2,
                                 encoder.features.denseblock3)
        elif layer == 4:
            return nn.Sequential(encoder.features.transition3,
                                 encoder.features.denseblock4,
                                 encoder.features.norm5,
                                 nn.ReLU())


def _get_layers_params(layers):
    """Flatten the parameters of several modules into one list."""
    return sum((list(l.parameters()) for l in layers), [])


def get_slice(features, start, end):
    """List slice of an indexable container; end == -1 means 'to the end'."""
    if end == -1:
        end = len(features)
    return [features[i] for i in range(start, end)]


# Spelling aliases kept for callers using the 'UNet' capitalization.
SelimSef_SpaceNet4_DenseNet161UNet = SelimSef_SpaceNet4_DenseNet161Unet
SelimSef_SpaceNet4_DenseNet121UNet = SelimSef_SpaceNet4_DenseNet121Unet