import torch
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
is_norm, kaiming_init)
from torch import nn as nn
from mmdet.models import NECKS
@NECKS.register_module()
class SECONDFPN(nn.Module):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Each input level goes through its own deblock
    (upsample layer -> norm layer -> ReLU). The resulting maps are
    concatenated along the channel dimension (dim 1).

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps.
            Defaults to ``[128, 128, 256]``.
        out_channels (list[int]): Output channels of feature maps.
            Defaults to ``[256, 256, 256]``.
        upsample_strides (list[int]): Strides used to upsample the feature
            maps. Each stride is also used as the kernel size of the
            corresponding upsample layer. Defaults to ``[1, 2, 4]``.
        norm_cfg (dict): Config dict of normalization layers. Defaults to
            ``dict(type='BN', eps=1e-3, momentum=0.01)``. For GroupNorm use
            e.g. ``dict(type='GN', num_groups=num_groups, eps=1e-3,
            affine=True)``.
        upsample_cfg (dict): Config dict of upsample layers. Defaults to
            ``dict(type='deconv', bias=False)``.
    """

    def __init__(self,
                 in_channels=None,
                 out_channels=None,
                 upsample_strides=None,
                 norm_cfg=None,
                 upsample_cfg=None):
        super(SECONDFPN, self).__init__()
        # Resolve defaults per call instead of using mutable default
        # arguments. Otherwise one list/dict object would be shared by every
        # instance, and self.in_channels / self.out_channels would alias it.
        if in_channels is None:
            in_channels = [128, 128, 256]
        if out_channels is None:
            out_channels = [256, 256, 256]
        if upsample_strides is None:
            upsample_strides = [1, 2, 4]
        if norm_cfg is None:
            norm_cfg = dict(type='BN', eps=1e-3, momentum=0.01)
        if upsample_cfg is None:
            upsample_cfg = dict(type='deconv', bias=False)

        assert len(out_channels) == len(upsample_strides) == len(in_channels), (
            'in_channels, out_channels and upsample_strides must have the '
            f'same length, got {len(in_channels)}, {len(out_channels)} and '
            f'{len(upsample_strides)}')
        # Copy so later mutation of these attributes cannot leak into the
        # caller's lists.
        self.in_channels = list(in_channels)
        self.out_channels = list(out_channels)

        deblocks = []
        for in_channel, out_channel, stride in zip(in_channels, out_channels,
                                                   upsample_strides):
            # kernel_size == stride gives non-overlapping upsampling windows.
            upsample_layer = build_upsample_layer(
                upsample_cfg,
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=stride,
                stride=stride)
            deblock = nn.Sequential(
                upsample_layer,
                build_norm_layer(norm_cfg, out_channel)[1],
                nn.ReLU(inplace=True))
            deblocks.append(deblock)
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
        """Initialize weights of FPN.

        Convolution layers, including the transposed convolutions built for
        a ``'deconv'`` upsample config, get ``kaiming_init``. Normalization
        layers get ``constant_init`` with value 1.
        """
        for m in self.modules():
            # nn.ConvTranspose2d does not inherit from nn.Conv2d, so it has to
            # be listed explicitly. Otherwise the deconv upsample layers keep
            # their default initialization.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                kaiming_init(m)
            elif is_norm(m):
                constant_init(m, 1)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[torch.Tensor] | tuple[torch.Tensor]): Multi-level 4D
                feature maps in (N, C, H, W) shape, one per entry of
                ``in_channels``. After upsampling, all levels must share the
                same spatial size so they can be concatenated.

        Returns:
            list[torch.Tensor]: A single-element list holding the upsampled
                feature maps concatenated along dim 1. With only one level,
                that level's output is returned without concatenation.
        """
        assert len(x) == len(self.in_channels), (
            f'expected {len(self.in_channels)} feature levels, got {len(x)}')
        ups = [deblock(feat) for deblock, feat in zip(self.deblocks, x)]
        # Skip torch.cat for a single level to avoid an unnecessary copy.
        out = torch.cat(ups, dim=1) if len(ups) > 1 else ups[0]
        return [out]