您的位置:

tensor list详解

一、什么是tensor list

1、Tensor


import torch
a = torch.tensor([1,2,3])

2、列表


lst = [1, 2, 3]

结合两者,即是tensor list


lst = [torch.tensor([1,2,3]), torch.tensor([4,5,6])]

二、tensor list的操作

1、赋值操作


lst[0] = torch.tensor([7,8,9])

2、切片操作


lst_slice = lst[:1]

3、拼接操作


new_lst = lst + [torch.tensor([10,11,12])]

三、tensor list的应用

1、神经网络的前向传播


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=1000, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=1)
        )

    def forward(self, x):
        xs = []
        for layer in self.fc_layers:
            x = layer(x)
            xs.append(x)
        return xs

2、序列标注的解码


def viterbi_decode(self, emissions: List[Tensor], transitions: Tensor,
                   decode_lengths: Optional[List[int]] = None) -> List[List[int]]:

    max_seq_length, batch_size, _ = emissions.shape
    mask = torch.ones(emissions.shape[:2], dtype=bool, device=emissions.device)
    path_scores = []
    path_indices = []
    last_idx = mask.sum(1) - 1
    # 发射概率
    emissions = emissions.permute(1, 0, 2)

    for i, (emission, batch_mask) in enumerate(zip(emissions, mask)):
        path_score, path_index = emission[0].unsqueeze(1), torch.zeros_like(emission[0]).unsqueeze(1).long()

        for j, (transition, score, last) in enumerate(zip(transitions, emission[1:], last_idx)):
            last = last.long()
            # 1、跨度
            broadcast_idx = batch_mask.unsqueeze(1).unsqueeze(2)
            broadcast_last = last.unsqueeze(1).unsqueeze(2)
            current_scores = path_score + transition + score.unsqueeze(1)
            current_scores[last == j] -= transitions[j]
            # 2、更新
            new_path_scores, new_path_indices = current_scores.max(dim=0)
            new_path_scores = torch.where(broadcast_idx, new_path_scores, path_score)
            new_path_indices = torch.where(broadcast_idx, new_path_indices, path_index)
            new_path_indices[last == j] = j

            path_score, path_index = new_path_scores, new_path_indices

        if decode_lengths is not None:
            path_index = [path_index[l, :dl] for l, dl in enumerate(decode_lengths[i].tolist())]
            path_score = [path_score[:dl, l] for l, dl in enumerate(decode_lengths[i].tolist())]

        path_scores.append(path_score)
        path_indices.append(path_index)

    # 计算总分值
    path_scores = [torch.stack(v).sum(0) for v in path_scores]
    return path_indices

四、tensor list的批量化处理

1、转换张量


# tensor列表转换为一个大张量
tensors = [torch.randn(3, 4), torch.randn(5, 6)]
batched_t = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)

2、打包/解包张量


# 生成一个长度列表,列表中每一个元素代表一个batch的数据的长度(即句子长度)
packed_sequence = torch.nn.utils.rnn.pack_sequence(tensors, enforce_sorted=False)
unpacked_sequence = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)

3、引入mask


# 先加入PAD,再加一个mask,把PAD排除在计算外
padded_sequence = nn.utils.rnn.pad_sequence(batched_tokens, batch_first=True, padding_value=vocab.token_to_idx[PAD])
mask = padded_sequence != vocab.token_to_idx[PAD]

五、总结与展望

Tensor List是PyTorch中经常使用到的数据结构,广泛应用于深度学习领域中的多个任务中,如神经网络的前向传播、序列标注的解码等。结合PyTorch自身的优势,我们可以高效地处理大量的数据,并实现了更加优秀的深度学习算法。在未来,Tensor List继续地用于深度学习算法实现中,我们期待Tensor List的技术在处理海量数据、提升模型精度、加速模型训练等多方面都可以有所突破。