How to Correctly Implement Batch Normalization in Neural Networks

I. What Is Batch Normalization?

During training, each input to the network (or the output of an intermediate layer) is standardized to have zero mean and unit standard deviation. Plain standardization, however, only maps the data onto a common scale; Batch Normalization additionally introduces two learnable parameters, a scale (gamma) and a shift (beta), which let the network learn more effectively.

Concretely, during training Batch Normalization computes the mean and variance of every mini-batch and re-normalizes the inputs with those statistics, as given by the formulas below.
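
For a mini-batch $\{x_1, \dots, x_m\}$ with learnable scale $\gamma$ and shift $\beta$, the standard Batch Normalization equations are:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_B)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta$$

A minimal NumPy implementation of the forward pass: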

import numpy as np

def batchnorm_forward(X, gamma, beta, eps):
    # per-feature mean and variance of the mini-batch
    mu = np.mean(X, axis=0)
    var = np.var(X, axis=0)
    std = np.sqrt(var + eps)
    # normalize, then scale and shift with the learnable parameters
    X_norm = (X - mu) / std
    out = gamma * X_norm + beta
    cache = (X, X_norm, mu, var, std, gamma, eps)  # everything a backward pass would need
    return out, cache
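
For example, a quick call on random data (the shapes here are arbitrary and chosen only for illustration):

X = np.random.randn(64, 100)                        # mini-batch of 64 examples with 100 features
gamma, beta = np.ones(100), np.zeros(100)           # identity scale and zero shift
out, cache = batchnorm_forward(X, gamma, beta, eps=1e-5)
print(out.mean(axis=0)[:3], out.std(axis=0)[:3])    # per-feature mean ~0 and std ~1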

II. Advantages of Batch Normalization

When a neural network is trained, each layer can suffer from internal covariate shift: as the parameters of earlier layers are updated, the distribution of the inputs to later layers changes, forcing those layers to keep re-adapting.

Batch Normalization mitigates this problem, allowing faster training while keeping it stable and efficient. It also makes it practical to increase network depth and thereby improve performance further.

Summary of advantages:

  • Allows larger learning rates to be used without training diverging.
  • Normalizes each layer's inputs so that their distribution stays within a stable range.
  • Speeds up training, reducing time and compute cost.

III. How to Implement Batch Normalization Correctly

To implement Batch Normalization correctly, pay attention to the following points:

1. Where to place Batch Normalization

In fully connected networks, Batch Normalization is usually placed between the linear layer and the activation function; in convolutional networks, it is usually placed between the convolution and the activation function.

The layers involved can be expressed in code as follows (naive NumPy implementations of a convolution layer and a batch-normalization layer); a short sketch chaining them in the recommended order follows the code.

import numpy as np

class Convolution:
    """Naive convolution layer (for illustration; not optimized)."""
    def __init__(self, W, b, stride=1, pad=0):
        self.W = W            # filters, shape (F, C, HH, WW)
        self.b = b            # biases, shape (F,)
        self.stride = stride
        self.pad = pad

    def forward(self, x):
        N, C, H, W = x.shape
        F, _, HH, WW = self.W.shape

        H_out = 1 + (H + 2*self.pad - HH) // self.stride
        W_out = 1 + (W + 2*self.pad - WW) // self.stride

        out = np.zeros((N, F, H_out, W_out))

        x_pad = np.pad(x, ((0,), (0,), (self.pad,), (self.pad,)), 'constant')

        for i in range(H_out):
            for j in range(W_out):
                h_start = i*self.stride
                h_end = h_start + HH
                w_start = j*self.stride
                w_end = w_start + WW

                x_pad_masked = x_pad[:, :, h_start:h_end, w_start:w_end]
                for k in range(F):
                    # sum over channel and spatial dimensions, then add the bias
                    out[:, k, i, j] = np.sum(x_pad_masked * self.W[k], axis=(1, 2, 3)) + self.b[k]
        self.cache = (x, x_pad, out)
        return out

    def backward(self, dout):
        x, x_pad, out = self.cache

        N, C, H, W = x.shape
        F, _, HH, WW = self.W.shape

        H_out = 1 + (H + 2*self.pad - HH) // self.stride
        W_out = 1 + (W + 2*self.pad - WW) // self.stride

        dx_pad = np.zeros_like(x_pad)
        dW = np.zeros_like(self.W)
        db = np.sum(dout, axis=(0, 2, 3))

        for i in range(H_out):
            for j in range(W_out):
                h_start = i*self.stride
                h_end = h_start + HH
                w_start = j*self.stride
                w_end = w_start + WW

                x_pad_masked = x_pad[:, :, h_start:h_end, w_start:w_end]
                for k in range(F):
                    dW[k] += np.sum(x_pad_masked * dout[:, k, i, j][:, None, None, None], axis=0)
                    dx_pad[:, :, h_start:h_end, w_start:w_end] += self.W[k] * dout[:, k, i, j][:, None, None, None]
        # strip the padding; slicing with -self.pad would break when pad == 0
        dx = dx_pad[:, :, self.pad:self.pad + H, self.pad:self.pad + W]

        return dx, dW, db


class BatchNormalization:
    def __init__(self, gamma, beta, momentum=0.9, eps=1e-5):
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        self.eps = eps
        # running statistics, accumulated during training and used at test time
        self.running_mean = None
        self.running_var = None

    def forward(self, x, is_training=True):
        self.input_shape = x.shape
        if x.ndim != 2:
            # conv input (N, C, H, W): normalize per channel
            N, C, H, W = x.shape
            x = x.transpose(0, 2, 3, 1).reshape(-1, C)

        if self.running_mean is None:
            self.running_mean = np.zeros(x.shape[1])
            self.running_var = np.ones(x.shape[1])

        if is_training:
            self.mu = np.mean(x, axis=0)
            self.var = np.var(x, axis=0)
            self.std = np.sqrt(self.var + self.eps)
            self.x_norm = (x - self.mu) / self.std
            self.x = x
            # update the running statistics for test time
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * self.mu
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * self.var
        else:
            # test mode: use the accumulated running statistics
            self.x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)

        out = self.gamma * self.x_norm + self.beta

        if len(self.input_shape) != 2:
            N, C, H, W = self.input_shape
            out = out.reshape(N, H, W, C).transpose(0, 3, 1, 2)
        return out

    def backward(self, dout):
        if dout.ndim != 2:
            N, C, H, W = dout.shape
            dout = dout.transpose(0, 2, 3, 1).reshape(-1, C)
        N = dout.shape[0]

        dgamma = np.sum(dout * self.x_norm, axis=0)
        dbeta = np.sum(dout, axis=0)

        dx_norm = dout * self.gamma
        dvar = np.sum(dx_norm * (self.x - self.mu), axis=0) * (-0.5) * (self.std ** -3)
        dmu = np.sum(-dx_norm / self.std, axis=0) + dvar * (-2.0 / N) * np.sum(self.x - self.mu, axis=0)
        dx = dx_norm / self.std + dvar * 2.0 / N * (self.x - self.mu) + dmu / N

        if len(self.input_shape) != 2:
            N, C, H, W = self.input_shape
            dx = dx.reshape(N, H, W, C).transpose(0, 3, 1, 2)

        return dx, dgamma, dbeta
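
As a minimal sketch of the placement described above (the filter shapes, random inputs, and ReLU here are illustrative assumptions, not part of a real model), the layers are chained in the order convolution → Batch Normalization → activation:

x = np.random.randn(8, 3, 16, 16)                       # a batch of 8 three-channel 16x16 inputs
conv = Convolution(W=np.random.randn(4, 3, 3, 3) * 0.01,
                   b=np.zeros(4), stride=1, pad=1)
bn = BatchNormalization(gamma=np.ones(4), beta=np.zeros(4))

h = conv.forward(x)                    # convolution: output shape (8, 4, 16, 16)
h = bn.forward(h, is_training=True)    # Batch Normalization, per channel
h = np.maximum(h, 0)                   # the ReLU activation comes after the BN layer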

2. Setting the learning rate

Introducing Batch Normalization changes how the network's weights are updated, which in turn affects training speed and validation accuracy, so the learning rate needs to be set appropriately.

A typical learning-rate setup in PyTorch looks like this:

import torch.optim as optim
from torch.optim import lr_scheduler

# model is the network being trained (defined elsewhere, e.g. the CNN in section IV)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # multiply lr by 0.1 every 10 epochs
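
The scheduler is then typically stepped once per epoch. A minimal sketch of the loop (num_epochs, train_loader, model, and loss_fn are assumed to be defined elsewhere):

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()            # decays the learning rate by gamma every step_size epochs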

3. Training and test modes of Batch Normalization

At test time, the Batch Normalization layers must use the running mean and variance accumulated during training rather than the statistics of the current batch, so the model has to be switched into evaluation mode.

The training and test setup in PyTorch looks like this:

# Training: BN uses the statistics of the current batch and updates its running estimates
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()

# Testing: model.eval() switches the BN layers to the accumulated running statistics
model.eval()
with torch.no_grad():
    test_outputs = model(test_inputs)
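
The difference between the two modes can be checked directly on a single BatchNorm layer (a toy example with random data, shown only for illustration):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(32, 4) * 3 + 5       # data with non-zero mean and non-unit variance

bn.train()
y_train = bn(x)                      # normalized with this batch's statistics
print(bn.running_mean)               # the running estimates have been updated

bn.eval()
y_eval = bn(x)                       # normalized with the running statistics instead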

IV. Practical Applications of Batch Normalization

In practice, Batch Normalization is used in CNNs (convolutional neural networks), MLPs (multi-layer perceptrons), and other architectures.

Below is a PyTorch implementation of a CNN that uses Batch Normalization (the 64 * 8 * 8 flattened size assumes 32x32 input images such as CIFAR-10):

import torch.nn as nn
import torch.nn.functional as functional


class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_bn = nn.BatchNorm2d(32)            # BN over the 32 conv channels
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 1024)        # assumes 32x32 input images
        self.fc1_bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # conv -> BN -> activation, as recommended above
        x = functional.relu(self.conv1_bn(self.conv1(x)))
        x = self.pool1(x)
        x = functional.relu(self.conv2_bn(self.conv2(x)))
        x = self.pool2(x)
        x = x.view(-1, 64 * 8 * 8)
        x = functional.relu(self.fc1_bn(self.fc1(x)))
        x = self.fc2(x)
        return x
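
A quick shape check (illustrative only; the dummy batch matches the 32x32 input assumption above):

import torch

model = CNN(num_classes=10)
dummy = torch.randn(4, 3, 32, 32)    # a batch of 4 fake 32x32 RGB images
logits = model(dummy)
print(logits.shape)                  # torch.Size([4, 10])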

V. Summary

This article has covered the concept of Batch Normalization, its advantages, how to implement it correctly, and its practical use in neural networks. Implementing it correctly requires attention to where the BN layer is placed in the network, how the learning rate is set, and the difference between training and test modes.