您的位置:

混合密度网络(MDN)简述

一、混合密度网络(MDN)概述

混合密度网络(Mixture Density Network, MDN)是一种基于神经网络的概率模型,可以预测多元输出的概率分布。MDN的前身为混合高斯模型,其本质是将高斯模型扩展至多元输出问题上。

一般来说,MDN可以用于建模连续变量或者分类问题中的多元数据输出。另外,由于MDN还具备自适应能力,因此它可以适用于模型具有复杂噪声结构的情况下,如自然语言处理、声音处理、或者图像识别。

二、混合密度网络(MDN)防止系数为1

在MDN中,防止系数为1是一种常见的技巧,它能够让输出分布变得更加连续,并且防止出现"die-off"的情况。"Die-off"指的是某些输出的分布出现尾部截断的问题,当这种情况发生的时候,模型预测的输出会变得非常敏感。防止系数为1的方法通常是通过将输出分别乘以一个极小的定值(如1e-5)来实现。


def output_tensor(y_class, y_res):
    # 将y_class的形状从(batch_size, 1)转换为(batch_size, K)
    y_class_flat = tf.reshape(y_class, [-1])
    ind = tf.range(tf.shape(y_class_flat)[0]) * K + tf.cast(y_class_flat, tf.int32)
    mu = tf.gather(tf.reshape(y_res[:K * size_out], [-1, size_out]), ind)
    sigma_hat = tf.exp(tf.gather(y_res[K * size_out:(2 * K + 1) * size_out], ind))
    sigma = sigma_hat * tf.pow(1 + tf.pow(sigma_hat, 2) * self.reg, -0.5)  # 适当的防止系数
    alpha = tf.reshape(tf.nn.softmax(tf.reshape(y_res[(2 * K + 1) * size_out:], [-1, K])), [-1, K])

    # 输出mu, sigma和alpha
    return mu, sigma, alpha

三、混合密度网络(MDN)评估

MDN评估一般采用负对数似然(Negative Loglikelihood)来进行。负对数似然是假定观测值服从预测输出分布后,在该分布下的似然函数的相反数。


def nll_loss(y_true, y_pred):
    mu, sigma, alpha = output_tensor(y_pred[:, :1], y_pred[:, 1:])
    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=alpha),
        components_distribution=tfd.Normal(
            loc=mu,
            scale=sigma))
    log_likelihood = gm.log_prob(y_true)
    return -tf.reduce_mean(log_likelihood)

四、混合密度网络(MDN)做分类

对于分类任务,我们可以在MDN末端采用softmax函数来作为每个输出类别的概率分布。在这种情况下,我们需要对损失函数进行改进,采用交叉熵损失函数来代替负对数似然损失函数。


def categorical_nll_loss(y_true, y_pred):
    # 将y_class的形状从(batch_size, 1)转换为(batch_size, K)
    y_class_flat = tf.reshape(y_true[:, :1], [-1])
    ind = tf.range(tf.shape(y_class_flat)[0]) * K + tf.cast(y_class_flat, tf.int32)
    alpha = tf.gather(tf.reshape(y_pred[:, :K * size_out], [-1, size_out]), ind)
    log_likelihood = -tf.math.log(alpha)
    loss = tf.reduce_mean(log_likelihood, axis=-1)
    return loss

五、混合密度网络(MDN)多元回归

多元回归任务通常需要预测多个输出变量,这时我们可以采用多个混合高斯分布来描述多个目标。在这种情况下,我们需要对多元高斯分布求解,具体方法可以采用“联合分布法”或者“条件概率法”。


K = 3  # 采用3个混合高斯分布作为输出

model = Sequential()
model.add(Dense(25, activation='relu'))
model.add(Dense(25, activation='relu'))
model.add(Dense(K * size_out + K + 1, activation='linear'))  # 输出为K * size_out个均值,K * size_out个标准差和K个系数

# 定义损失函数为负对数似然函数
model.compile(loss=nll_loss, optimizer=Adam(lr=0.001))

# 进行训练
model.fit(data_train, label_train, epochs=100)

六、混合密度网络(MDN)多变量输出

多变量输出问题是指输入变量为多维度,输出变量也为多维度的条件概率分布问题。在这种情况下,我们可以采用独立的多元高斯分布来分别描述每个输出维度的条件概率,或者采用图形模型来描述多维度之间的条件概率关系


def create_model(input_shape, output_shape):
    input_layer = Input(shape=input_shape)
    hidden = Dense(units=128, activation='relu')(input_layer)
    hidden = Dense(units=64, activation='relu')(hidden)
    
    # 为每个输出维度定义输出分布
    output_layers = []
    activations = ['linear', 'exponential', 'sigmoid', 'tanh']
    for i in range(output_shape[0]):
        out = Dense(units=3 * output_shape[1], activation=activations[i])(hidden)
        output_layers.append(out)
    output_layer = Concatenate(axis=-1)(output_layers)
    model = Model(inputs=[input_layer], outputs=[output_layer])
    
    # 定义损失函数为负对数似然函数
    model.compile(optimizer='adam', loss=nll_loss)
    return model

七、混合密度网络(MDN)相关论文

A Density Network Approach to Improving the Generalization of Deep Neural Networks

Mixture Density Networks

A Mixture Density Network for Bankruptcy Prediction Using Alternative Data