Fairseq Transformer: problem with the decoder receiving multiple outputs from the encoder
In my fairseq model the decoder receives several outputs from the encoder, namely a list of intermediate encoder states (one per layer). They are handed over through two module-level (global) lists, encoder_inner_states = [] and final_encoder_inner_states = []; a sketch of this data flow follows the encoder code. The encoder:
def forward(self, src_tokens, src_lengths):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
    """
    # reset the module-level list of per-layer states for this forward pass
    encoder_inner_states.clear()
    if self.history is not None:
        self.history.clean()

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # add the embedding into the layer history
    if self.history is not None:
        self.history.add(x)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    attn_weight_list = []
    inner_states = []

    # encoder layers: collect every layer's output
    for layer_id, layer in enumerate(self.layers):
        x, attn_weight = layer(x, encoder_padding_mask)
        if self.history is not None:
            self.history.add(x)
        if self.history is not None:
            y = self.history.pop()
        encoder_inner_states.append(x)
        attn_weight_list.append(attn_weight)

    if self.history is not None:
        x = self.history.pop()

    # project every collected state through its own fc layer
    final_encoder_inner_states.clear()
    for fc_layer_id, fc_layer in enumerate(self.fc_layers):
        y = fc_layer(encoder_inner_states[fc_layer_id])
        y = F.dropout(y, p=self.dropout, training=self.training)
        if self.normalize:
            y = self.layer_norm(y)
        final_encoder_inner_states.append(y)

    return {
        'encoder_out': x,  # T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
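For clarity, the data flow above is: every encoder layer's output is appended to the module-level list encoder_inner_states, and after the loop each of those states is passed through its own fc_layer (plus dropout and optional layer norm) into final_encoder_inner_states. Below is a minimal, self-contained sketch of the same idea written against stock PyTorch rather than fairseq; the class name ToyInnerStateEncoder, the toy dimensions and the 'inner_states' key are my own assumptions, and the per-layer states are returned inside the output dict instead of being kept in globals.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyInnerStateEncoder(nn.Module):
    """Sketch: collect one state per encoder layer and project each with its own FC."""

    def __init__(self, embed_dim=8, num_layers=3, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(embed_dim, nhead=2) for _ in range(num_layers)]
        )
        # one projection per layer, mirroring self.fc_layers in the code above
        self.fc_layers = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
        )
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, padding_mask=None):
        # x: T x B x C, padding_mask: B x T (True at padding positions)
        inner_states = []
        for layer in self.layers:
            x = layer(x, src_key_padding_mask=padding_mask)
            inner_states.append(x)
        # project every intermediate state through its own fc layer
        final_inner_states = []
        for fc, state in zip(self.fc_layers, inner_states):
            y = F.dropout(fc(state), p=self.dropout, training=self.training)
            final_inner_states.append(self.layer_norm(y))
        # return the per-layer states together with the usual outputs,
        # instead of stashing them in module-level globals
        return {
            'encoder_out': x,                      # T x B x C
            'encoder_padding_mask': padding_mask,  # B x T
            'inner_states': final_inner_states,    # list of T x B x C tensors
        }

Returning the list inside the dict keeps each batch's per-layer states together with its encoder_out; module-level globals do not give that guarantee by themselves.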
The decoder: instead of the usual encoder_out, each decoder layer takes its input from the global list final_encoder_inner_states (a sketch of this per-layer wiring follows the decoder code):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
    """
    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for input feeding/teacher forcing
        encoder_out (Tensor, optional): output from the encoder, used for
            encoder-side attention
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`

    Returns:
        tuple:
            - the last decoder layer's output of shape `(batch, tgt_len,
              vocab)`
            - the last decoder layer's attention weights of shape `(batch,
              tgt_len, src_len)`
    """
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if self.project_in_dim is not None:
        x = self.project_in_dim(x)
    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    attn = None

    inner_states = [x]
    enc_dec_attn_weight_list = []

    # decoder layers: layer i cross-attends over the i-th projected encoder state
    for layer_id, layer in enumerate(self.layers):
        x, attn = layer(
            x,
            final_encoder_inner_states[layer_id],  # instead of encoder_out['encoder_out']
            encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
            incremental_state,
            self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
        )
        inner_states.append(x)
        enc_dec_attn_weight_list.append(attn)

    if self.normalize:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

    return x, {'attn': attn, 'inner_states': inner_states}
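On the consumption side, decoder layer layer_id cross-attends over final_encoder_inner_states[layer_id] rather than the single encoder_out tensor, so the number of decoder layers has to match the number of collected encoder states. A minimal sketch of that per-layer-memory pattern with stock PyTorch layers (ToyPerLayerMemoryDecoder and the shapes below are illustrative assumptions, not fairseq's decoder API, which additionally takes incremental_state and a future mask as shown above):

import torch
import torch.nn as nn

class ToyPerLayerMemoryDecoder(nn.Module):
    """Sketch: decoder layer i attends over encoder state i, not the last encoder output."""

    def __init__(self, embed_dim=8, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerDecoderLayer(embed_dim, nhead=2) for _ in range(num_layers)]
        )

    def forward(self, tgt, inner_states, memory_padding_mask=None, tgt_mask=None):
        # tgt: T_tgt x B x C, inner_states: list of T_src x B x C (one per decoder layer)
        assert len(inner_states) == len(self.layers), \
            "need exactly one encoder state per decoder layer"
        x = tgt
        for layer, memory in zip(self.layers, inner_states):
            # each decoder layer gets its own encoder memory
            x = layer(
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_padding_mask,
            )
        return x

# usage: three per-layer encoder states of shape T_src x B x C
inner_states = [torch.randn(5, 2, 8) for _ in range(3)]
tgt = torch.randn(4, 2, 8)
decoder = ToyPerLayerMemoryDecoder()
print(decoder(tgt, inner_states).shape)  # torch.Size([4, 2, 8])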
But I run into the following problems: [screenshot]