您的位置:

图注意力网络

一、简介

图注意力网络(GAT,Graph Attention Networks),是一种基于注意力机制(Attention Mechanism)的图神经网络(Graph Neural Network),由Petar Veličković等人于2017年提出。与传统的图卷积神经网络(Graph Convolutional Network,简称GCN)相比,GAT不仅考虑了每个节点本身的特征,还考虑了该节点与其周围节点之间的关系,从而表达出全局信息。该网络在图节点分类、节点聚类和图分类等领域表现出色,被广泛应用于社交网络、推荐系统、蛋白质结构预测等领域。

二、注意力机制

注意力机制是GAT中最重要的组成部分之一,该机制通过为每个节点学习一个权重分布来动态地调整网络中每个节点的重要性。在基本的注意力机制中,目标节点与周围节点之间的相似性通过计算内积来量化:

1. 基本的注意力机制


def attention(input_x, neighboors_features):
    # 卷积核初始化,设输入特征维度为d,分为h组
    w = tf.Variable(tf.zeros([input_x.shape[-1], input_x.shape[-1]]))
    a = tf.Variable(tf.zeros([2*input_x.shape[-1], 1]))
    # 卷积核赋值,初始化
    tf.compat.v1.random_normal_initializer(stddev=0.1)()(w)
    tf.compat.v1.random_normal_initializer(stddev=0.1)()(a)
    # 节点矩阵X,和邻居节点矩阵
    h = tf.matmul(input_x, w)   # (N, h)
    f_1, f_2 = tf.reshape(h, [-1, 1, h.shape[-1]]), \
           tf.reshape(neighboors_features, [-1, neighboors_features.shape[1], h.shape[-1]])
    # self-attention 注意力机制本质上就是计算出相似性,这里比较的是节点自身的特征以及相邻节点的特征
    # 相似度结果经过 softmax,使得相似度在每行之间的分布是一个概率分布,最后使用相似度的分布对邻居特征进行加权平均。
    # 然后进行再一次卷积并剪枝
    attention_lev_1 = tf.nn.softmax(tf.reduce_sum(tf.multiply(f_1, f_2), axis=-1))  # (N, M)
    attention_lev_2 = tf.multiply(tf.reshape(attention_lev_1, [-1, neighboors_features.shape[1], 1]), neighboors_features)
    h_level_2 = tf.reduce_sum(attention_lev_2, axis=1)  # (N, h)
    return tf.nn.relu(tf.matmul(h_level_2, w))

一个节点的邻居节点特征与节点本身特征加权平均后得到该节点的新特征向量。

2. 多头注意力机制

在GAT中,作者提出了多头注意力机制,引入了一个超参数heads,代表了图中每个节点可以学到不同注意力权重分布,进而提取具有不同重要性特征:


class MultiHeadAttention(tf.keras.Model):
    
    def __init__(self, num_head, embedding_size, output_dim, feature_map=None, activation=tf.nn.elu):
        super(MultiHeadAttention, self).__init__()
        self.num_head = num_head
        self.embedding_size = embedding_size
        self.output_dim = output_dim
        self.feature_map = feature_map
        self.activation = activation
        
        # 计算特征向量
        self.embedding_w = tf.Variable(tf.compat.v1.random_normal([self.embedding_size, self.output_dim], stddev=0.1))
        self.embedding_b = tf.Variable(tf.zeros([self.output_dim]))
        # 计算注意力权重
        self.attention_w = tf.Variable(tf.compat.v1.random_normal([self.output_dim, 1], stddev=0.1))
        self.attention_b = tf.Variable(tf.zeros([self.num_head, 1]))
        
    def call(self, input_x, neighboors_features):
        assert isinstance(neighboors_features, list)
        input_x_ = tf.reshape(tf.matmul(input_x, self.embedding_w), [-1, self.num_head, self.output_dim])
        n_features = [tf.reshape(tf.matmul(f, self.embedding_w) , [-1, self.num_head, self.output_dim]) for f in neighboors_features]
        # 每一个头都有自己的注意力层。注意力权重通过对节点和邻居节点的特征的内积得到。
        # Attention weights are calculated by dot product of node features and neighbor node features.
        att_val = tf.concat([tf.matmul(tf.nn.leaky_relu(tf.matmul(input_x_, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(input_x_)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(input_x_)[0], 1, 1])), 
                              tf.concat([tf.matmul(tf.nn.leaky_relu(tf.matmul(input_x_, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(input_x_)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(input_x_)[0], 1, 1])), 
                                        tf.matmul(tf.nn.leaky_relu(tf.matmul(nf, tf.tile(tf.expand_dims(self.attention_w, 0), [tf.shape(nf)[0], 1, 1])) + tf.tile(self.attention_b, [tf.shape(nf)[0], 1, 1]))], axis=1)], axis=1)
        att_val = self.activation(att_val + tf.tile(self.embedding_b, [tf.shape(input_x_)[0], 1, 1]))
        # 最后通过一个全连接层进行融合
        out = tf.reduce_mean(att_val, axis=1)
        if self.feature_map:
            out = tf.concat([self.feature_map(input_x), out], axis=1)
        return out

三、GAT 的应用

图注意力网络已经被广泛应用于实际的场景中。

1. 社交网络

社交网络中存在好友、关注等关系,每个用户的影响力都不同,GAT可以将好友之间的关系视作图数据,学习每个用户与其好友需关注的重要性,从而进行信息传播和影响力分析。

2. 推荐系统

推荐系统中用户和物品都可以看作图中的节点,节点之间的边表示用户和物品之间的交互行为,GAT可以对节点进行特征抽取,通过对相邻节点进行注意力加权来提升推荐效果。

3. 蛋白质结构预测

蛋白质结构预测是生物学中的重要问题之一,无法通过传统的生物实验方法进行预测。GAT可以将蛋白质中的氨基酸、二级结构等信息视作图数据,学习节点之间的关系特征,通过注意力机制来预测蛋白质结构。

四、总结

图注意力网络是一种基于注意力机制的图神经网络,适用于处理任意类型的图结构数据。其多头注意力机制能够学习不同注意力权重分布,提取具有不同重要性的特征。该网络已经被广泛应用于社交网络、推荐系统、蛋白质结构预测等领域。