WideResNet

WideResNet is an improved variant of ResNet. It is a deep neural network model proposed by Sergey Zagoruyko and Nikos Komodakis in the paper "Wide Residual Networks". Starting from the original ResNet, it widens the network instead of making it deeper and achieved strong results on a range of image classification datasets.

1. Introduction to WideResNet

Traditional ResNets are built from residual blocks, mainly to address the vanishing-gradient problem. WideResNet instead improves the model by increasing the channel width of each layer, which gives it stronger representational power and higher classification accuracy.

The key characteristics of the WideResNet model are as follows:

  • Width factor (widen factor k): every convolutional layer uses k times as many channels (i.e. more convolution filters), so each layer extracts more discriminative features without making the network deeper (a worked example follows this list).
  • Depth factor: the overall depth is still controlled by the number of residual blocks stacked in each group.
  • Dropout: Dropout is applied between the convolutions of a residual block to reduce overfitting and keep the model's variance in check.
  • Batch Normalization: batch normalization is used to speed up training and improve generalization.
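
As a concrete illustration of how these two factors interact, the sketch below (using the same depth = 6*n + 4 convention as the code later in this article) works out the layout of the popular WRN-28-10 configuration:

# Sketch: how depth and widen factor translate into a WRN layout (WRN-28-10).
depth, k = 28, 10
n = (depth - 4) // 6                      # residual blocks per group -> 4
channels = [16, 16 * k, 32 * k, 64 * k]   # channel widths -> [16, 160, 320, 640]
print(f"blocks per group: {n}, channel widths: {channels}")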

2. The Structure of WideResNet

WideResNet is composed of several parts: an input layer, a convolutional stem, residual blocks, a global pooling layer, a fully connected layer, and an output layer.

A WideResNet residual block contains two convolutional layers and a skip connection. After the convolutions, the input is routed through the shortcut and added directly to the block output; the implementation below uses the pre-activation order BN → ReLU → conv.

import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_channels, out_channels, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        if in_channels != out_channels or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(x), inplace=True)
        out = self.conv1(out)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(F.relu(self.bn2(out), inplace=True))
        out += self.shortcut(x)
        return out
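
As a quick shape check, the following minimal sketch (assuming the BasicBlock defined above; the 16 → 160 channel change corresponds to the first group of a network with k = 10) runs a random CIFAR-sized tensor through one block:

import torch

block = BasicBlock(in_channels=16, out_channels=160, stride=1, dropout_rate=0.3)
x = torch.randn(2, 16, 32, 32)   # batch of 2 CIFAR-sized feature maps
y = block(x)
print(y.shape)                   # torch.Size([2, 160, 32, 32]); the 1x1 shortcut matches the widened channels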

The convolutional stem used here consists of a 3×3 convolution, a Batch Normalization layer, a ReLU activation, optional Dropout, and a 2×2 max-pooling layer. (The original WRN stem for CIFAR is just the 3×3 convolution; the max pool here additionally halves the spatial resolution before the residual groups.)

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        if self.dropout:
            out = self.dropout(out)
        out = self.pool(out)
        return out
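
A similar sanity check for the stem (a minimal sketch assuming the ConvLayer above) shows that the 2×2 max pool halves the 32×32 CIFAR input:

import torch

stem = ConvLayer(in_channels=3, out_channels=16, dropout_rate=0.3)
x = torch.randn(2, 3, 32, 32)    # a batch of two CIFAR-10 images
print(stem(x).shape)             # torch.Size([2, 16, 16, 16])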

3. Applications of WideResNet

WideResNet is widely used for computer vision tasks such as object recognition, image segmentation, and scene understanding. On CIFAR-10 and CIFAR-100 it achieved state-of-the-art accuracy at the time of publication, and on ImageNet wide variants such as WRN-50-2 reach Top-1 and Top-5 accuracy comparable to much deeper ResNets.

Below is a complete example of training a WideResNet on the CIFAR-10 dataset:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

class WRN(nn.Module):
    def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10):
        super(WRN, self).__init__()
        self.depth = depth
        self.widen_factor = widen_factor
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes

        k = widen_factor  # width multiplier

        # Network architecture
        n = (depth - 4) // 6
        block = BasicBlock
        channels = [16, 16 * k, 32 * k, 64 * k]
        self.features = nn.Sequential(
            ConvLayer(3, channels[0], dropout_rate=dropout_rate),
            self._make_layer(block, channels[0], channels[1], n, dropout_rate=dropout_rate),
            self._make_layer(block, channels[1], channels[2], n, stride=2, dropout_rate=dropout_rate),
            self._make_layer(block, channels[2], channels[3], n, stride=2, dropout_rate=dropout_rate),
            nn.BatchNorm2d(channels[3]),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(channels[3], num_classes)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride=1, dropout_rate=0):
        # The first block changes the channel count and (optionally) downsamples;
        # the remaining blocks keep the same width with stride 1.
        layers = []
        for i in range(num_blocks):
            layers.append(
                block(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    dropout_rate=dropout_rate,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

# Training settings
batch_size = 128
epochs = 50
lr = 0.1
momentum = 0.9
weight_decay = 5e-4
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load data
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Create model
model = WRN().to(device)

# Optimization
optimizer = optim.SGD(
    model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# Training loop
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Train - loss: {loss.item():.4f}")
    print(f"Test  - loss: {test_loss:.4f}, accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n")