在Pytorch实现seq2seq模型中,对于一个batch中的每个序列,其长度可能不一致。对于长度不一致的序列,需要进行pad操作,使其长度一致。但是,在计算loss的时候,pad部分的贡献必须要被剔除,否则会带来噪声。
为了解决这一问题,可以使用mask技术,即使用一个mask张量对loss进行掩码,将pad部分设置为0,只计算有效部分的loss。
下面是实现seq2seq时对loss进行mask的方式的完整攻略:
1.创建mask张量
通过给定的输入序列长度,创建一个bool掩码,其中有效部分为True,pad部分为False。
def create_mask(seq_len, pad_idx):
mask = (torch.ones(seq_len) * pad_idx).unsqueeze(0) != torch.arange(seq_len).unsqueeze(1)
return mask.to(device)
其中,seq_len为每个序列的长度,pad_idx为pad的token索引,此处默认使用0进行pad。
2.计算loss时掩码
在计算loss时,将mask张量与计算得到的loss张量相乘即可实现mask。
mask = create_mask(target_seq_len, pad_idx) # 创建mask张量
loss = criterion(output, target_seqs) # 计算loss
loss = (loss * mask.float()).sum() / mask.sum() # mask掩码
3.示例说明
下面给出两个示例,更好地理解如何使用mask对seq2seq模型的loss进行掩码。
假设我们有如下两个序列:
其中,我们使用3个token来表示输入和输出序列,对应的pad_idx为0。那么,我们需要将输入和输出序列转换为相同的长度,这里设定为5。那么,经过pad之后,就可以得到如下矩阵:
# input_seq:['I', 'love', 'you']
input_seqs = [[1, 3, 2, 0, 0]] # 0表示pad
# target_seq:['Ich', 'liebe', 'dich']
target_seqs = [[4, 5, 6, 2, 0]] # 0表示pad
其中,1/3/2对应的是输入序列中的'I'/'love'/'you',4/5/6对应的是目标序列中的'Ich'/'liebe'/'dich'。
接下来,我们需要创建掩码张量,对于pad部分置为False,其他部分置为True。
pad_idx = 0
input_seq_len = 3 # 输入序列长度
target_seq_len = 3 # 目标序列长度
input_mask = create_mask(input_seq_len, pad_idx)
# input_mask: [[ True, True, True, False, False]]
target_mask = create_mask(target_seq_len, pad_idx)
# target_mask: [[ True, True, True, False, False]]
最后,计算loss时,使用mask张量掩码:
output = model(input_seqs, input_mask, target_seqs[:, :-1], target_mask[:, :-1])
loss = criterion(output, target_seqs[:, 1:])
# 对验证集batch中每个序列的loss进行求和并求平均
loss = (loss * target_mask[:, 1:].float()).sum() / target_mask[:, 1:].sum()
这里,我们首先使用model计算模型输出,然后计算loss,最后使用target_mask掩码。需要注意的是,这里的target_seqs需要去掉最后的一个token,也就是'pad',以保证input_seqs和target_seqs的长度相同。