一、什么是tensor list
1、Tensor
import torch
a = torch.tensor([1,2,3])
2、列表
lst = [1, 2, 3]
结合两者,即是tensor list
lst = [torch.tensor([1,2,3]), torch.tensor([4,5,6])]
二、tensor list的操作
1、赋值操作
lst[0] = torch.tensor([7,8,9])
2、切片操作
lst_slice = lst[:1]
3、拼接操作
new_lst = lst + [torch.tensor([10,11,12])]
三、tensor list的应用
1、神经网络的前向传播
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc_layers = nn.Sequential(
nn.Linear(in_features=1000, out_features=100),
nn.ReLU(),
nn.Linear(in_features=100, out_features=10),
nn.ReLU(),
nn.Linear(in_features=10, out_features=1)
)
def forward(self, x):
xs = []
for layer in self.fc_layers:
x = layer(x)
xs.append(x)
return xs
2、序列标注的解码
def viterbi_decode(self, emissions: List[Tensor], transitions: Tensor,
decode_lengths: Optional[List[int]] = None) -> List[List[int]]:
max_seq_length, batch_size, _ = emissions.shape
mask = torch.ones(emissions.shape[:2], dtype=bool, device=emissions.device)
path_scores = []
path_indices = []
last_idx = mask.sum(1) - 1
# 发射概率
emissions = emissions.permute(1, 0, 2)
for i, (emission, batch_mask) in enumerate(zip(emissions, mask)):
path_score, path_index = emission[0].unsqueeze(1), torch.zeros_like(emission[0]).unsqueeze(1).long()
for j, (transition, score, last) in enumerate(zip(transitions, emission[1:], last_idx)):
last = last.long()
# 1、跨度
broadcast_idx = batch_mask.unsqueeze(1).unsqueeze(2)
broadcast_last = last.unsqueeze(1).unsqueeze(2)
current_scores = path_score + transition + score.unsqueeze(1)
current_scores[last == j] -= transitions[j]
# 2、更新
new_path_scores, new_path_indices = current_scores.max(dim=0)
new_path_scores = torch.where(broadcast_idx, new_path_scores, path_score)
new_path_indices = torch.where(broadcast_idx, new_path_indices, path_index)
new_path_indices[last == j] = j
path_score, path_index = new_path_scores, new_path_indices
if decode_lengths is not None:
path_index = [path_index[l, :dl] for l, dl in enumerate(decode_lengths[i].tolist())]
path_score = [path_score[:dl, l] for l, dl in enumerate(decode_lengths[i].tolist())]
path_scores.append(path_score)
path_indices.append(path_index)
# 计算总分值
path_scores = [torch.stack(v).sum(0) for v in path_scores]
return path_indices
四、tensor list的批量化处理
1、转换张量
# tensor列表转换为一个大张量
tensors = [torch.randn(3, 4), torch.randn(5, 6)]
batched_t = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
2、打包/解包张量
# 生成一个长度列表,列表中每一个元素代表一个batch的数据的长度(即句子长度)
packed_sequence = torch.nn.utils.rnn.pack_sequence(tensors, enforce_sorted=False)
unpacked_sequence = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)
3、引入mask
# 先加入PAD,再加一个mask,把PAD排除在计算外
padded_sequence = nn.utils.rnn.pad_sequence(batched_tokens, batch_first=True, padding_value=vocab.token_to_idx[PAD])
mask = padded_sequence != vocab.token_to_idx[PAD]
五、总结与展望
Tensor List是PyTorch中经常使用到的数据结构,广泛应用于深度学习领域中的多个任务中,如神经网络的前向传播、序列标注的解码等。结合PyTorch自身的优势,我们可以高效地处理大量的数据,并实现了更加优秀的深度学习算法。在未来,Tensor List继续地用于深度学习算法实现中,我们期待Tensor List的技术在处理海量数据、提升模型精度、加速模型训练等多方面都可以有所突破。