# Source code for layout_data.models.fpn.fpn_head

import torch.nn as nn
import torch.nn.functional as F


class Conv3x3GNReLU(nn.Module):
    """3x3 convolution -> GroupNorm (32 groups) -> ReLU, with optional resize.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution. It is
            also passed to ``nn.GroupNorm(32, ...)``, which requires it to be
            divisible by 32.
        upsample: When true, ``forward`` bilinearly resizes the activated
            output to the ``size`` it is given.
    """

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample

        # Bias is disabled on the convolution; GroupNorm carries its own
        # affine shift.
        conv = nn.Conv2d(
            in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
        )
        norm = nn.GroupNorm(32, out_channels)
        activation = nn.ReLU(inplace=True)
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x, size):
        """Apply the conv block and, if enabled, resize the result.

        Args:
            x: Input tensor accepted by ``nn.Conv2d``.
            size: Target spatial size forwarded to ``F.interpolate``. It is
                ignored when ``self.upsample`` is false.

        Returns:
            The activated feature map, resized to ``size`` when upsampling
            is enabled.
        """
        features = self.block(x)
        if not self.upsample:
            return features
        return F.interpolate(
            features, size=size, mode="bilinear", align_corners=True
        )
class FPNBlock(nn.Module):
    """Merge a coarser pyramid feature with a projected skip feature.

    Args:
        pyramid_channels: Channel count of the pyramid feature; the skip
            feature is projected to this many channels.
        skip_channels: Channel count of the incoming skip feature.
    """

    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
        # 1x1 convolution that brings the skip feature to the pyramid width.
        self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)

    def forward(self, x):
        """Resize the pyramid feature to the skip's size and add them.

        Args:
            x: A two-element sequence ``(pyramid_feature, skip_feature)``.

        Returns:
            The bilinearly resized pyramid feature plus the projected skip
            feature.
        """
        pyramid, skip = x
        # Resize to the skip's last two dimensions rather than using a fixed
        # scale factor.
        target_size = skip.size()[-2:]
        resized = F.interpolate(
            pyramid, size=target_size, mode="bilinear", align_corners=True
        )
        lateral = self.skip_conv(skip)
        return resized + lateral
class SegmentationBlock(nn.Module):
    """Chain of ``Conv3x3GNReLU`` stages applied one after another.

    The first stage maps ``in_channels`` to ``out_channels`` and resizes only
    when ``n_upsamples`` is non-zero. Each further stage (there are
    ``n_upsamples - 1`` of them) keeps ``out_channels`` and always resizes.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by every stage.
        n_upsamples: Number of resizing stages. ``0`` still creates a single
            non-resizing stage.
    """

    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()
        self.blocks = [
            Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
        ]
        # range(1, n) is empty for n <= 1, so no extra stages are added then.
        self.blocks.extend(
            Conv3x3GNReLU(out_channels, out_channels, upsample=True)
            for _ in range(1, n_upsamples)
        )

        # ``self.blocks`` is a plain list, so each stage is registered
        # explicitly. The "Block_<i>" names are part of the state_dict keys
        # and must stay stable.
        self.blocks_name = []
        for i, block in enumerate(self.blocks):
            name = "Block_{}".format(i)
            self.add_module(name, block)
            self.blocks_name.append(name)

    def forward(self, x, sizes=None):
        """Run every stage in order, giving stage ``i`` the size ``sizes[i]``.

        Args:
            x: Input feature map.
            sizes: Sequence of target spatial sizes, one per stage. Entries
                for stages that do not resize are ignored and may be omitted;
                ``None`` is treated as an empty sequence.

        Returns:
            The output of the last stage.

        Raises:
            IndexError: If a resizing stage has no corresponding entry in
                ``sizes``.
        """
        # A ``None`` default avoids sharing one mutable list across calls.
        if sizes is None:
            sizes = ()

        for i, block_name in enumerate(self.blocks_name):
            block = getattr(self, block_name)
            if i < len(sizes):
                size = sizes[i]
            elif block.upsample:
                raise IndexError(
                    "sizes has {} entries but stage {} upsamples and needs a "
                    "target size".format(len(sizes), i)
                )
            else:
                # A non-resizing stage never reads its size argument.
                size = None
            x = block(x, size)
        return x
class FPNDecoder(nn.Module):
    """Feature-pyramid decoder that fuses four encoder levels into one map.

    Args:
        encoder_channels: Channel counts of the encoder features, indexed so
            that ``[0]`` belongs to the feature fed to ``conv1`` (``c5`` in
            ``forward``) and ``[1]``, ``[2]``, ``[3]`` to ``c4``, ``c3``,
            ``c2`` respectively.
        pyramid_channels: Channel count of the pyramid features.
        segmentation_channels: Channel count of each segmentation branch.
        final_upsampling: Scale factor applied to the output; skipped when
            ``None`` or not greater than 1.
        final_channels: Channel count of the returned map.
        dropout: Probability for the ``nn.Dropout2d`` applied before the
            final convolution.
    """

    def __init__(
        self,
        encoder_channels,
        pyramid_channels=256,
        segmentation_channels=128,
        final_upsampling=4,
        final_channels=1,
        dropout=0.2,
    ):
        super().__init__()
        self.final_upsampling = final_upsampling

        self.conv1 = nn.Conv2d(
            encoder_channels[0], pyramid_channels, kernel_size=(1, 1)
        )

        # Top-down merge blocks, registered in the order p4, p3, p2.
        for name, channel_index in (("p4", 1), ("p3", 2), ("p2", 3)):
            setattr(
                self,
                name,
                FPNBlock(pyramid_channels, encoder_channels[channel_index]),
            )

        # Segmentation branches; deeper levels get more resizing stages.
        for name, n_upsamples in (("s5", 3), ("s4", 2), ("s3", 1), ("s2", 0)):
            setattr(
                self,
                name,
                SegmentationBlock(
                    pyramid_channels,
                    segmentation_channels,
                    n_upsamples=n_upsamples,
                ),
            )

        self.dropout = nn.Dropout2d(p=dropout, inplace=True)
        self.final_conv = nn.Conv2d(
            segmentation_channels, final_channels, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Decode five encoder features into a single output map.

        Args:
            x: Sequence of exactly five feature maps ``(_, c2, c3, c4, c5)``;
                the first one is ignored.

        Returns:
            The output of ``final_conv``, bilinearly upscaled by
            ``final_upsampling`` when that is set and greater than 1.
        """
        _, c2, c3, c4, c5 = x

        # Top-down pathway.
        p5 = self.conv1(c5)
        p4 = self.p4([p5, c4])
        p3 = self.p3([p4, c3])
        p2 = self.p2([p3, c2])

        # Last two dimensions of each encoder feature, used as resize targets.
        size4, size3, size2 = (feat.size()[-2:] for feat in (c4, c3, c2))

        # Every branch ends at the c2 size so the outputs can be summed.
        s5 = self.s5(p5, sizes=[size4, size3, size2])
        s4 = self.s4(p4, sizes=[size3, size2])
        s3 = self.s3(p3, sizes=[size2])
        s2 = self.s2(p2, sizes=[size2])

        fused = s5 + s4 + s3 + s2
        fused = self.dropout(fused)
        out = self.final_conv(fused)

        scale = self.final_upsampling
        if scale is None or not scale > 1:
            return out
        return F.interpolate(
            out,
            scale_factor=scale,
            mode="bilinear",
            align_corners=True,
        )