您的位置:

深度学习中的推荐算法SASRec

一、SASRec简介

SASRec是一种仅基于序列信息进行推荐的深度神经网络模型。相比于其他推荐算法,SASRec有较高的准确率和效率,同时能够处理长时序列数据。

SASRec最初由Wang等人在2018年发表在ICDM上,它的主要思想是将序列数据看作序列交互模式, 每个交互都提供一些期望和隐式反馈。SASRec通过使用自注意力机制替代掉传统的卷积或RNN模型,从而避免了序列信息丢失的问题。

二、SASRec的架构与算法实现

1.架构

SASRec的架构主要由两部分组成:序列-序列网络(SSN)和非线性全连接层。SSN是一个基于自注意力机制的深度神经网络,其主要作用是对序列交互模式进行建模并产生预测结果。非线性全连接层作为输出层,将SSN的最后一层隐藏状态转化为目标项的预测结果。

class SASRec(nn.Module):
    def __init__(self, item_size, dim, num_heads, num_blocks, max_len, dropout):
        super(SASRec, self).__init__()
        self.item_size = item_size
        self.dim = dim
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.max_len = max_len

        self.item_embeddings = nn.Embedding(item_size, dim, padding_idx=0)
        self.pos_embedding = nn.Embedding(max_len, dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(dim, num_heads, dropout) for _ in range(num_blocks)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(dim, item_size)

    def forward(self, items, seq_len):
        # items: B x L
        # seq_len: B
        item_embs = self.item_embeddings(items)
        pos = torch.arange(self.max_len - 1, -1, -1.).to(items.device)
        pos_embs = self.pos_embedding(pos)
        embs = item_embs + pos_embs
        mask = get_mask(seq_len, self.max_len)
        h = embs
        for transformer in self.transformer_blocks:
            h = transformer(h, mask)
        h = self.dropout(h[:, 0, :])
        output = self.fc(h)
        return output

2.算法实现

SASRec的算法实现基于Pytorch深度学习框架,由若干个模块组成,包括TransformerBlock、PositionwiseFeedForward、MultiHeadAttention以及SASRec本体模块。

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super(TransformerBlock, self).__init__()
        self.attn_layer_norm = nn.LayerNorm(dim)
        self.ffn_layer_norm = nn.LayerNorm(dim)
        self.ffn = PositionwiseFeedForward(dim, dropout)
        self.attn = MultiHeadAttention(num_heads, dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # x: B x L x H
        # mask: B x L
        # self-attention
        h = self.attn_layer_norm(x)
        h, _ = self.attn(h, h, h, mask)
        h = self.dropout(h) + x
        # feed forward
        o = self.ffn_layer_norm(h)
        o = self.dropout(self.ffn(o)) + h
        return o

三、SASRec的优缺点

1.优点

(1)SASRec直接依赖于序列数据,可以更加充分地利用序列中隐藏的用户兴趣信息。

(2)SASRec通过自注意力机制对长序列进行建模和处理,可以有效地保留序列中的重要信息,提取有用的特征。

(3)SASRec的训练效率相比于其他深度学习算法较高,而且其预测效果优秀。

2.缺点

(1)对于非序列信息存在的情况,SASRec的效果可能不如其他的推荐算法。

(2)SASRec的用户兴趣建模受限于序列长度,如果序列过短或过长,其预测效果可能会受到影响。

四、SASRec的应用

SASRec目前已经在多个推荐系统中得到了应用,例如Amazon和Netflix等。此外,基于SASRec提出的多个改进算法,如SR-Transformer,也在不断探索中应用于实际的推荐场景。

五、总结

SASRec是一种基于序列信息的高效推荐算法,其核心思想是使用自注意力机制进行序列建模并进行预测。SASRec具有较高的准确性和处理能力,在推荐系统的实际应用中具有广阔的前景。