您的位置:

ResNet-18的全面解析

ResNet-18是一种非常著名的深度神经网络,它在ImageNet数据集上表现优异,被广泛应用于计算机视觉领域。本文将从网络结构、Skip connection、残差模块、全局平均池化等多个方面对ResNet-18进行详细的阐述。

一、网络结构

ResNet-18是由18个卷积层和全连接层组成的深度卷积神经网络,这个网络结构中每一个卷积层都有一个残差块,其中包含若干个卷积层和batch normalization层。在卷积层之间,使用了stride=2,stride=1,和shortcut来改善性能。


import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = nn.ReLU()(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = nn.AvgPool2d(kernel_size=4)(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def resnet18():
    return ResNet(BasicBlock, [2,2,2,2])

这个网络结构包含4个残差模块,每个残差模块包含多个BasicBlock,最后一个BasicBlock中做了全局平均池化。输出层为一个完全连接层,用于分类。

二、Skip Connection

ResNet-18网络的核心是skip connection。在传统CNN网络中,前面的层处理信息经过多次池化和卷积后被深度神经网络较深的层所覆盖,导致前面的信息被遗忘,难以训练。skip connection解决了这个问题,可以将前面的信息一路留给更深的网络层,可以直接传递信号而不会使其消失。

在ResNet中,skip connection是一种shortcut,负责将输入源直接传递到残差模块中去。例如,在一个两个卷积层组成的模块中,输入x通过第一个卷积层和relu激活函数之后,跳过第二个卷积层,直接到达该模块的输出,如下所示:

x = conv1(x)
out = conv2(x) + x

这种shortcut实在是太经典了,使得Deep Residual Network成为当时最优秀的网络之一。

三、残差模块

残差模块是ResNet-18网络的基础部分。每一个残差模块都是由两个连接级联的卷积层组合而成。其中的第一个卷积层可以是3*3,5*5或者7*7,第二个卷积层通常都是3*3。

下面是一个基本的残差模块的实现:


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = nn.ReLU()(out)
        return out

在这个残差模块中,残差块有两个卷积层,分别用于卷积源数据和传递经过第一个卷积的结果到一个shortcut通道,使上一个残差块的输出可以和这个残差块的输入之间消除一些高阶多项式的噪声,最终输出结果。

四、全局平均池化

全局平均池化是ResNet-18网络的最后一步操作。它的目的是将特征图中每个像素划分为整个空间的均值,此操作将特征图进一步压缩为单个值,以用作分类器的输入。

下面是一个实现全局平均池化的代码块:


out = nn.AvgPool2d(kernel_size=4)(out)
out = out.view(out.size(0), -1)

在这个代码块中,第一行中的AvgPool2d函数用于把特征映射进行均值池化,第二行基于当前大小调整输出的形状以便于将数据输入到全连接层中。

总结

ResNet-18是一种非常成功的深度神经网络,它不仅可以支持更深层次的架构,而且可以在一定程度上减少过拟合的问题。在本文中,我们对ResNet-18的网络结构、Skip Connection机制、残差模块、全局平均池化等多个方面进行了详尽的解析,并提供了相应的代码示例,帮助读者更好的理解这个深度神经网络。