WideResNet is an improved variant of ResNet, a deep neural network proposed by Sergey Zagoruyko and Nikos Komodakis in the paper "Wide Residual Networks". It takes the original ResNet as its starting point, widens the network instead of deepening it, and achieved state-of-the-art results on several image classification benchmarks at the time.
1. Introduction to WideResNet
The classic ResNet is built from residual blocks, mainly to address the vanishing-gradient problem. WideResNet improves on it by increasing the channel width of each layer, which gives the model greater representational power and better classification accuracy.
The main characteristics of the WideResNet model are as follows (see the short sketch after this list):
- Widening factor (k): instead of stacking ever more layers, every convolutional layer uses k times as many feature-map channels, so each layer can extract more discriminative features.
- Depth factor: the overall depth is still controlled by the number of residual blocks per group.
- Dropout: dropout is applied between the convolutions inside the residual blocks to reduce overfitting and keep the variance of the model in check.
- Batch Normalization: batch normalization is used to speed up training and improve generalization.
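To make the depth and width factors concrete, here is a minimal sketch (illustrative code, not from the paper) of how a WRN-depth-k configuration maps to the number of blocks per group and the channel widths; it follows the n = (depth - 4) / 6 and [16, 16k, 32k, 64k] layout used by the WRN class later in this article.

```python
# Minimal sketch (illustrative): how depth d and widening factor k determine a
# WRN-d-k layout with three groups of basic blocks (d = 6n + 4).
def wrn_config(depth: int, k: int):
    assert (depth - 4) % 6 == 0, "depth must have the form 6n + 4"
    n = (depth - 4) // 6                      # residual blocks per group
    channels = [16, 16 * k, 32 * k, 64 * k]   # stem width, then the three widened groups
    return n, channels

print(wrn_config(28, 10))  # (4, [16, 160, 320, 640]) -> WRN-28-10
print(wrn_config(16, 8))   # (2, [16, 128, 256, 512]) -> WRN-16-8
```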
2. Architecture of WideResNet
WideResNet is made up of several components: an input layer, a convolutional stem, residual blocks, a global pooling layer, and a fully connected output layer.
A WideResNet residual block contains two convolutional layers plus a skip connection: after the convolutions, the block input is added to the output through a shortcut.
```python
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_channels, out_channels, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation residual block: BN -> ReLU -> conv, twice, with dropout in between."""

    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        # Projection shortcut when the shape changes, identity otherwise.
        if in_channels != out_channels or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(x), inplace=True)
        out = self.conv1(out)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(F.relu(self.bn2(out), inplace=True))
        out += self.shortcut(x)
        return out
```
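As a quick sanity check (an illustrative snippet, not part of the original article), feeding random tensors through the block shows how it widens the channels and, with stride=2, halves the spatial resolution while the shortcut keeps the two branches compatible:

```python
import torch

# Illustrative shape check for the BasicBlock defined above.
x = torch.randn(8, 16, 32, 32)                  # a batch of CIFAR-sized feature maps
widen = BasicBlock(16, 160, stride=1)           # widen 16 -> 160 channels, keep 32x32
print(widen(x).shape)                           # torch.Size([8, 160, 32, 32])

down = BasicBlock(160, 320, stride=2)           # widen and halve the resolution
print(down(torch.randn(8, 160, 32, 32)).shape)  # torch.Size([8, 320, 16, 16])
```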
The stem convolutional layer of WideResNet consists of a 3×3 convolution followed by a Batch Normalization layer, a ReLU activation, and a 2×2 max-pooling layer.
```python
class ConvLayer(nn.Module):
    """Stem layer: 3x3 conv -> BN -> ReLU -> (optional dropout) -> 2x2 max pool."""

    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        if self.dropout:
            out = self.dropout(out)
        out = self.pool(out)
        return out
```
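Two details of this stem are worth noting. The convolution uses bias=False because the BatchNorm2d layer that follows has its own learnable shift, making a separate bias redundant. The 2×2 max pooling halves the spatial resolution, so a 32×32 CIFAR image reaches the first group of residual blocks as a 16×16 feature map; this differs slightly from the original paper, whose CIFAR stem is a plain 3×3 convolution without pooling.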
3. Applications of WideResNet
WideResNet is used in a wide range of computer vision tasks such as object recognition, image segmentation, and scene understanding. In the original paper it achieved state-of-the-art accuracy on CIFAR-10, CIFAR-100, and SVHN at the time, and competitive Top-1 and Top-5 accuracy on ImageNet.
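For such tasks, a common starting point is the ImageNet-pretrained wide ResNet variant shipped with torchvision. The sketch below is an assumed transfer-learning setup, not code from the WRN paper; the 10-class head and the freezing policy are illustrative choices.

```python
import torch
import torch.nn as nn
from torchvision import models

# Sketch: transfer learning with torchvision's ImageNet-pretrained Wide ResNet-50-2
# (requires torchvision >= 0.13 for the weights API). Head size and freezing policy
# are illustrative assumptions.
model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)

for p in model.parameters():        # freeze the pretrained backbone
    p.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, 10)  # replace the classification head

# Only the new head is optimized.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)
```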
Below is the complete training code for WideResNet on the CIFAR-10 dataset (it reuses the BasicBlock and ConvLayer modules defined above):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class WRN(nn.Module):
    def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10):
        super(WRN, self).__init__()
        self.depth = depth
        self.widen_factor = widen_factor
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes
        k = widen_factor  # width multiplier

        # Network architecture: depth = 6n + 4, i.e. three groups of n basic blocks
        n = (depth - 4) // 6
        block = BasicBlock
        channels = [16, 16 * k, 32 * k, 64 * k]
        self.in_channels = channels[0]

        self.features = nn.Sequential(
            ConvLayer(3, channels[0], dropout_rate=dropout_rate),
            self._make_layer(block, channels[1], n, dropout_rate=dropout_rate),
            self._make_layer(block, channels[2], n, stride=2, dropout_rate=dropout_rate),
            self._make_layer(block, channels[3], n, stride=2, dropout_rate=dropout_rate),
            nn.BatchNorm2d(channels[3]),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(channels[3], num_classes)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, num_blocks, stride=1, dropout_rate=0):
        # The first block of each group changes the channel count and, when stride=2,
        # the spatial resolution; the remaining blocks keep both fixed.
        layers = []
        for i in range(num_blocks):
            layers.append(
                block(
                    self.in_channels if i == 0 else out_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    dropout_rate=dropout_rate,
                )
            )
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out


# Training settings
batch_size = 128
epochs = 50
lr = 0.1
momentum = 0.9
weight_decay = 5e-4
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load data
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Create model
model = WRN().to(device)

# Optimization
optimizer = optim.SGD(
    model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# Training loop
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Train - loss: {loss.item():.4f}")
    print(
        f"Test  - loss: {test_loss:.4f}, "
        f"accuracy: {correct}/{len(test_loader.dataset)} "
        f"({100. * correct / len(test_loader.dataset):.2f}%)\n"
    )
```
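The WRN class above defaults to WRN-28-10. Other standard configurations from the paper can be obtained simply by changing depth and widen_factor; the snippet below is an illustrative usage note, with the training settings above assumed unchanged.

```python
# Illustrative: other common configurations from the paper, using the WRN class above.
wrn_16_8 = WRN(depth=16, widen_factor=8, dropout_rate=0.3, num_classes=10)
wrn_40_2 = WRN(depth=40, widen_factor=2, dropout_rate=0.0, num_classes=10)

# For CIFAR-100, only the classifier head changes:
wrn_28_10_c100 = WRN(depth=28, widen_factor=10, num_classes=100)
```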