DAMO-YOLO的Neck( Efficient RepGFPN)详解

[复制链接]
查看722 | 回复0 | 2023-8-23 11:59:29 | 显示全部楼层 |阅读模式

 这个图是有点题目的,在GiraffeNeckV2代码中只有了5个Fusion Block(图中有6个)
https://github.com/tinyvision/DAMO-YOLO/blob/master/damo/base_models/necks/giraffe_fpn_btn.py
代码中只有5个CSPStage
所以我自己画了一个总体图,在github上提了个issue,得到了原作者的肯定
I think the pictures in your paper are not rigorous in several places · Issue #91 · tinyvision/DAMO-YOLO · GitHub

 
想要看懂Neck部分,只必要看懂Fusion Block在做什么就行了,其他部分和PAN差不太多
  1. class CSPStage(nn.Module):   
  2.     def __init__(self,
  3.                  block_fn,
  4.                  ch_in,
  5.                  ch_hidden_ratio,
  6.                  ch_out,
  7.                  n,
  8.                  act='swish',
  9.                  spp=False):
  10.         super(CSPStage, self).__init__()
  11.         split_ratio = 2
  12.         ch_first = int(ch_out // split_ratio)
  13.         ch_mid = int(ch_out - ch_first)
  14.         self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)
  15.         self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)
  16.         self.convs = nn.Sequential()
  17.         next_ch_in = ch_mid
  18.         for i in range(n):
  19.             if block_fn == 'BasicBlock_3x3_Reverse':
  20.                 self.convs.add_module(
  21.                     str(i),
  22.                     BasicBlock_3x3_Reverse(next_ch_in,
  23.                                            ch_hidden_ratio,
  24.                                            ch_mid,
  25.                                            act=act,
  26.                                            shortcut=True))
  27.             else:
  28.                 raise NotImplementedError
  29.             if i == (n - 1) // 2 and spp:
  30.                 self.convs.add_module(
  31.                     'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
  32.             next_ch_in = ch_mid
  33.         self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)
  34.     def forward(self, x):
  35.         y1 = self.conv1(x)
  36.         y2 = self.conv2(x)
  37.         mid_out = [y1]
  38.         for conv in self.convs:
  39.             y2 = conv(y2)
  40.             mid_out.append(y2)
  41.         y = torch.cat(mid_out, axis=1)
  42.         y = self.conv3(y)
  43.         return y
复制代码
以上是CSPStage的代码,要想看懂,我们得先看懂ConvBNAct、BasicBlock_3x3_Reverse这两个类
  1. class ConvBNAct(nn.Module):
  2.     """A Conv2d -> Batchnorm -> silu/leaky relu block"""
  3.     def __init__(
  4.         self,
  5.         in_channels,
  6.         out_channels,
  7.         ksize,
  8.         stride=1,
  9.         groups=1,
  10.         bias=False,
  11.         act='silu',
  12.         norm='bn',
  13.         reparam=False,
  14.     ):
  15.         super().__init__()
  16.         # same padding
  17.         pad = (ksize - 1) // 2
  18.         self.conv = nn.Conv2d(
  19.             in_channels,
  20.             out_channels,
  21.             kernel_size=ksize,
  22.             stride=stride,
  23.             padding=pad,
  24.             groups=groups,
  25.             bias=bias,
  26.         )
  27.         if norm is not None:
  28.             self.bn = get_norm(norm, out_channels, inplace=True)
  29.         if act is not None:
  30.             self.act = get_activation(act, inplace=True)
  31.         self.with_norm = norm is not None
  32.         self.with_act = act is not None
  33.     def forward(self, x):
  34.         x = self.conv(x)
  35.         if self.with_norm:
  36.             x = self.bn(x)
  37.         if self.with_act:
  38.             x = self.act(x)
  39.         return x
  40.     def fuseforward(self, x):
  41.         return self.act(self.conv(x))
复制代码
ConvBNAct照旧很悦目懂的,Conv +BN + SiLU就完事了(也可用别的激活函数,文章用SiLU)

 假如设置了groups参数就酿成了组卷积了


  1. class BasicBlock_3x3_Reverse(nn.Module):
  2.     def __init__(self,
  3.                  ch_in,
  4.                  ch_hidden_ratio,
  5.                  ch_out,
  6.                  act='relu',
  7.                  shortcut=True):
  8.         super(BasicBlock_3x3_Reverse, self).__init__()
  9.         assert ch_in == ch_out
  10.         ch_hidden = int(ch_in * ch_hidden_ratio)
  11.         self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act)
  12.         self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act)
  13.         self.shortcut = shortcut
  14.     def forward(self, x):
  15.         y = self.conv2(x)
  16.         y = self.conv1(y)
  17.         if self.shortcut:
  18.             return x + y
  19.         else:
  20.             return y
复制代码
要看懂BasicBlock_3x3_Reverse这个类,就得了解RepConv类,这个类就是根据RepVGG网络的RepVGGBlock改的
  1. class RepConv(nn.Module):
  2.     '''RepConv is a basic rep-style block, including training and deploy status
  3.     Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
  4.     '''
  5.     def __init__(self,
  6.                  in_channels,
  7.                  out_channels,
  8.                  kernel_size=3,
  9.                  stride=1,
  10.                  padding=1,
  11.                  dilation=1,
  12.                  groups=1,
  13.                  padding_mode='zeros',
  14.                  deploy=False,
  15.                  act='relu',
  16.                  norm=None):
  17.         super(RepConv, self).__init__()
  18.         self.deploy = deploy
  19.         self.groups = groups
  20.         self.in_channels = in_channels
  21.         self.out_channels = out_channels
  22.         assert kernel_size == 3
  23.         assert padding == 1
  24.         padding_11 = padding - kernel_size // 2
  25.         if isinstance(act, str):
  26.             self.nonlinearity = get_activation(act)
  27.         else:
  28.             self.nonlinearity = act
  29.         if deploy:
  30.             self.rbr_reparam = nn.Conv2d(in_channels=in_channels,
  31.                                          out_channels=out_channels,
  32.                                          kernel_size=kernel_size,
  33.                                          stride=stride,
  34.                                          padding=padding,
  35.                                          dilation=dilation,
  36.                                          groups=groups,
  37.                                          bias=True,
  38.                                          padding_mode=padding_mode)
  39.         else:
  40.             self.rbr_identity = None
  41.             self.rbr_dense = conv_bn(in_channels=in_channels,
  42.                                      out_channels=out_channels,
  43.                                      kernel_size=kernel_size,
  44.                                      stride=stride,
  45.                                      padding=padding,
  46.                                      groups=groups)
  47.             self.rbr_1x1 = conv_bn(in_channels=in_channels,
  48.                                    out_channels=out_channels,
  49.                                    kernel_size=1,
  50.                                    stride=stride,
  51.                                    padding=padding_11,
  52.                                    groups=groups)
  53.     def forward(self, inputs):
  54.         '''Forward process'''
  55.         if hasattr(self, 'rbr_reparam'):
  56.             return self.nonlinearity(self.rbr_reparam(inputs))
  57.         if self.rbr_identity is None:
  58.             id_out = 0
  59.         else:
  60.             id_out = self.rbr_identity(inputs)
  61.         return self.nonlinearity(
  62.             self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
  63.     def get_equivalent_kernel_bias(self):
  64.         kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
  65.         kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
  66.         kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
  67.         return kernel3x3 + self._pad_1x1_to_3x3_tensor(
  68.             kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
  69.     def _pad_1x1_to_3x3_tensor(self, kernel1x1):
  70.         if kernel1x1 is None:
  71.             return 0
  72.         else:
  73.             return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  74.     def _fuse_bn_tensor(self, branch):
  75.         if branch is None:
  76.             return 0, 0
  77.         if isinstance(branch, nn.Sequential):
  78.             kernel = branch.conv.weight
  79.             running_mean = branch.bn.running_mean
  80.             running_var = branch.bn.running_var
  81.             gamma = branch.bn.weight
  82.             beta = branch.bn.bias
  83.             eps = branch.bn.eps
  84.         else:
  85.             assert isinstance(branch, nn.BatchNorm2d)
  86.             if not hasattr(self, 'id_tensor'):
  87.                 input_dim = self.in_channels // self.groups
  88.                 kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
  89.                                         dtype=np.float32)
  90.                 for i in range(self.in_channels):
  91.                     kernel_value[i, i % input_dim, 1, 1] = 1
  92.                 self.id_tensor = torch.from_numpy(kernel_value).to(
  93.                     branch.weight.device)
  94.             kernel = self.id_tensor
  95.             running_mean = branch.running_mean
  96.             running_var = branch.running_var
  97.             gamma = branch.weight
  98.             beta = branch.bias
  99.             eps = branch.eps
  100.         std = (running_var + eps).sqrt()
  101.         t = (gamma / std).reshape(-1, 1, 1, 1)
  102.         return kernel * t, beta - running_mean * gamma / std
  103.     def switch_to_deploy(self):
  104.         if hasattr(self, 'rbr_reparam'):
  105.             return
  106.         kernel, bias = self.get_equivalent_kernel_bias()
  107.         self.rbr_reparam = nn.Conv2d(
  108.             in_channels=self.rbr_dense.conv.in_channels,
  109.             out_channels=self.rbr_dense.conv.out_channels,
  110.             kernel_size=self.rbr_dense.conv.kernel_size,
  111.             stride=self.rbr_dense.conv.stride,
  112.             padding=self.rbr_dense.conv.padding,
  113.             dilation=self.rbr_dense.conv.dilation,
  114.             groups=self.rbr_dense.conv.groups,
  115.             bias=True)
  116.         self.rbr_reparam.weight.data = kernel
  117.         self.rbr_reparam.bias.data = bias
  118.         for para in self.parameters():
  119.             para.detach_()
  120.         self.__delattr__('rbr_dense')
  121.         self.__delattr__('rbr_1x1')
  122.         if hasattr(self, 'rbr_identity'):
  123.             self.__delattr__('rbr_identity')
  124.         if hasattr(self, 'id_tensor'):
  125.             self.__delattr__('id_tensor')
  126.         self.deploy = True
复制代码

 RepConv的特点是布局重参数化,训练时采用三条分支,推理时将三个分支融合在一起,大大淘汰了推理时间(发起看看RepVGG的讲授视频),我图画得太丑了

  RepConv采用的两分支的布局(a)



 其他细节有缘再更,代码不难,渐渐看完万能懂。有写的不对的地方请包涵


来源:https://blog.csdn.net/weixin_43227262/article/details/129368505
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则