恰当新手搭建ResNet50残差网络的架构图(最全)

[复制链接]
查看588 | 回复0 | 2023-8-23 11:59:59 | 显示全部楼层 |阅读模式
恰当新手搭建ResNet50残差网络的架构图+代码(最全)
网上的教程大多复杂难明,不恰当新手,本来神经网络就难,这些教程本身更难,对新手极度不友好,因此本身做的这个架构图和写的代码,面向新手,大神跳过

  1. from torch import nn
  2. import torch
  3. from torchviz import make_dot
  4. class box(nn.Module):
  5.     def __init__(self, in_channels, index=999, stride=1, downsample=False):
  6.         super(box, self).__init__()
  7.         last_stride = 2  # 虚残差中卷积核的步距
  8.         if downsample:  # 虚残差结构
  9.             f_out_channnels = in_channels * 2
  10.             out_channels = int(in_channels / 2)
  11.             if index == 0:  # here is first core
  12.                 in_channels = int(in_channels / 2)  # 第一层设置为128,是方便了后面的统一处理
  13.                 out_channels = in_channels
  14.                 f_out_channnels = in_channels * 4
  15.                 last_stride = 1
  16.                 stride = 1
  17.         else:  # 实残差
  18.             f_out_channnels = in_channels * 1
  19.             out_channels = int(in_channels / 4)
  20.         self.downsample = downsample
  21.         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)
  22.         self.relu = nn.ReLU(inplace=True)
  23.         self.bn1 = nn.BatchNorm2d(out_channels)
  24.         self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  25.         self.relu = nn.ReLU(inplace=True)
  26.         self.bn2 = nn.BatchNorm2d(out_channels)
  27.         self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=f_out_channnels, kernel_size=1, stride=1, padding=0)
  28.         self.bn3 = nn.BatchNorm2d(f_out_channnels)
  29.         self.fe = nn.Sequential(
  30.             nn.Conv2d(in_channels=in_channels, out_channels=f_out_channnels, kernel_size=1, stride=last_stride,  padding=0, bias=False),
  31.             nn.BatchNorm2d(f_out_channnels),
  32.         )
  33.     def forward(self, x):
  34.         identity = x
  35.         if self.downsample:
  36.             identity = self.fe(x)
  37.         x = self.conv1(x)
  38.         x = self.bn1(x)
  39.         x = self.relu(x)
  40.         x = self.conv2(x)
  41.         x = self.bn2(x)
  42.         x = self.relu(x)
  43.         x = self.conv3(x)
  44.         x = self.bn3(x)
  45.         out = x + identity
  46.         out = self.relu(out)
  47.         return out
  48. class New50(nn.Module):
  49.     def __init__(self,in_out, num_classes=4):
  50.         super(New50, self).__init__()
  51.         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2,padding=3, bias=False)
  52.         self.bn1 = nn.BatchNorm2d(64)
  53.         self.relu = nn.ReLU(inplace=True)
  54.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  55.         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
  56.         self.fc = nn.Linear(512 * 4 , num_classes)
  57.         layers = []
  58.         for index, z in enumerate(in_out):
  59.             in_ch = z[0]                                      # 这里通道/2
  60.             layers.append(box(in_channels=in_ch, stride=2, downsample=z[2], index=index))  # 这里处理第一层
  61.             for i in range(1, z[1]):
  62.                 layers.append(box(in_channels=z[3]))  # 这里处理其他两层
  63.         # print(layers)
  64.         self.fes = nn.Sequential(*layers)
  65.     def forward(self, x):
  66.         x = self.conv1(x)
  67.         x = self.bn1(x)
  68.         x = self.relu(x)
  69.         x = self.maxpool(x)
  70.         x = self.fes(x)
  71.         x = self.avgpool(x)
  72.         x = torch.flatten(x, 1)
  73.         x = self.fc(x)
  74.         return x
  75. in_out = [(128, 3, True, 256), (256, 4, True, 512), (512, 6, True, 1024), (1024, 3, True, 2048)]
  76. s = New50(in_out=in_out)
  77. def resnet500():
  78.     return New50(in_out=in_out)
  79. '''每层的第一层输入
  80.    每层重复的次数
  81.    是否走虚残差
  82.    每层的第二个卷积核的输入'''
复制代码
后续还会上传ResNet30,FCN,UNet等架构图和代码。

来源:https://blog.csdn.net/qq_44697987/article/details/128178998
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则