一、Gumbel Softmax简介
Gumbel Softmax是一种基于采样的概率分布生成算法,它用于从一个具有固定参数的分布中生成一组概率分布。 具体地说,它可以通过使用伯努利分布对样本进行采样来生成一个概率分布序列。该算法的应用包括生成离散变量的序列和特权探测机制等。
通俗点解释Gumbel分布就是从两个独立的均匀分布变量中减去log(-log(U))的值的和,其中U是从均匀分布中随机采样的。Gumbel Softmax随机向量的生成操作包括两个步骤:
1、从一个Gumbel(0,1)分布采样,并使用负对数对其进行缩放
2、通过Softmax函数将结果转换为一个概率向量(一组凸和组件)
import torch
def gumbel_softmax_sample(logits, temperature):
y = logits + torch.randn_like(logits)
return F.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature, hard=False):
"""
ST-gum: *ST*ochastic *GUM*ble-softmax.
"""
y = gumbel_softmax_sample(logits, temperature)
if hard:
y_hard = torch.zeros_like(y)
max_value, max_index = y.max(dim=-1, keepdim=True)
y_hard.scatter_(dim=-1, index=max_index, value=1.0)
y = (y_hard - y).detach() + y
return y
二、Gumbel Softmax的生成过程
假设我们有一组由骰子掷出的结果构成的序列,该序列中每个骰子掷出的数字之和为10。如果我们知道有多少种不同的序列可以得到这个和,我们就可以得到一个概率分布,该分布揭示了对于所有可能的序列而言,生成和为10的序列的概率是多少。在Gumbel Softmax中,我们使用负对数softmax将掷骰子的操作抽象为随机变量采样并将结果映射到概率分布空间上的一组向量。这里掷骰过程的示例代码:
num_trials, num_faces, target_value = 1000, 10, 10
dice_faces = torch.randint(1, num_faces + 1, size=(num_trials, target_value))
cumulative_sum = dice_faces.cumsum(dim=1)
indicator = (cumulative_sum == target_value)
target_count = indicator.sum(dim=0)
plt.figure(figsize=(8, 6))
plt.hist(target_count.numpy(), bins=np.arange(6, 40), density=True)
plt.xlabel('Number of successful events')
plt.ylabel('Probability')
plt.title('10d10 success count')
plt.show()
三、Gumbel Softmax的应用场景
Gumbel Softmax的应用场景主要涉及到使用生成模型处理离散数据,具体包括: 1、离散序列生成,即通过输入生成符合要求的离散序列; 2、文本生成,即用于自然语言处理中,基于巨量的训练数据进行建模,能够生成新的语言句子; 3、推荐系统,即基于大数据模型,进行用户行为分析和个性化推荐。 以上三个应用场景在神经网络建模中占有重要的地位,由于该模型具备相对较强的分布拟合能力和计算效率,被广泛应用于当代深度学习模型中。
四、Gumbel Softmax的优缺点
优点: 1、Gumbel Softmax算法快速且高效,适用于大规模离散数据的建模和模拟; 2、Gumbel Softmax算法显著优于其他基于概率分布手段生成离散序列的算法,具备更强的分布拟合能力和高阶统计特性; 缺点: 1、Gumbel Softmax算法对于小型数据集处理效果并不优秀,对于输入空间受限的生成模型表现并不理想; 2、Gumbel Softmax算法存在监督数据缺失问题,对于与数据样本无法自动识别的离散空间作用不佳; 3、Gumbel Softmax算法中存在过热问题,具体来说,由于采样过程中的噪声,模型可能会生成具有极小概率的事件,这会对生成效果产生不利影响。