Fairseq Transformer: problem with the decoder receiving multiple encoder outputs

I am modifying Fairseq so that the decoder receives several outputs from the encoder, namely several intermediate encoder states. They are passed between the two modules through two global lists, encoder_inner_states = [] and final_encoder_inner_states = []. The encoder:

def forward(self, src_tokens, src_lengths):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
    """
    encoder_inner_states.clear()
    if self.history is not None:
        self.history.clean()
    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # add emb into history
    if self.history is not None:
        self.history.add(x)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None
    #print("emb:{}".format(x.size()))
    
    #intra_sim
    attn_weight_list = []
    inner_states=[]
    # encoder layers
    for layer_id, layer in enumerate(self.layers):
        x, attn_weight = layer(x, encoder_padding_mask)

        if self.history is not None:
            self.history.add(x)
            # NOTE: the popped value is never used below
            y = self.history.pop()

        # keep every layer's output in the module-level list
        encoder_inner_states.append(x)
        attn_weight_list.append(attn_weight)

    if self.history is not None:
        x = self.history.pop()

    # project every stored layer output and keep it in the second module-level list,
    # which the decoder reads instead of the returned encoder_out
    final_encoder_inner_states.clear()
    for fc_layer_id, fc_layer in enumerate(self.fc_layers):
        y = fc_layer(encoder_inner_states[fc_layer_id])
        y = F.dropout(y, p=self.dropout, training=self.training)
        if self.normalize:
            y = self.layer_norm(y)
        final_encoder_inner_states.append(y)

    # NOTE: the projected states are not returned; the decoder reads the globals
    return {
        'encoder_out': x,# T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
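
For comparison, here is a minimal, self-contained sketch of carrying the per-layer states inside the returned dictionary instead of in module-level globals. This is only a toy illustration under my own assumptions, not the Fairseq API: ToyEncoder, the stand-in Linear layers, and the 'inner_states' key are hypothetical names.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Toy encoder that returns every per-layer state in its output dict."""

    def __init__(self, vocab=100, embed_dim=8, num_layers=3, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.embed_tokens = nn.Embedding(vocab, embed_dim, padding_idx=padding_idx)
        # stand-ins for the transformer encoder layers and for self.fc_layers above
        self.layers = nn.ModuleList(nn.Linear(embed_dim, embed_dim) for _ in range(num_layers))
        self.fc_layers = nn.ModuleList(nn.Linear(embed_dim, embed_dim) for _ in range(num_layers))

    def forward(self, src_tokens):
        x = self.embed_tokens(src_tokens).transpose(0, 1)  # B x T x C -> T x B x C
        inner_states = []
        for layer in self.layers:
            x = torch.relu(layer(x))
            inner_states.append(x)
        # project each stored state, mirroring the fc_layers loop above
        projected = [fc(s) for fc, s in zip(self.fc_layers, inner_states)]
        return {
            'encoder_out': x,                                          # T x B x C
            'encoder_padding_mask': src_tokens.eq(self.padding_idx),   # B x T
            'inner_states': projected,                                 # list of T x B x C
        }

Because the states travel with the returned dict, nothing is shared between batches or module instances, and the decoder's dependency on them stays explicit.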

The decoder: instead of the usual encoder_out, each layer uses the data from the global list final_encoder_inner_states:

def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
    """
    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for input feeding/teacher forcing
        encoder_out (Tensor, optional): output from the encoder, used for
            encoder-side attention
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`

    Returns:
        tuple:
            - the last decoder layer's output of shape `(batch, tgt_len,
              vocab)`
            - the last decoder layer's attention weights of shape `(batch,
              tgt_len, src_len)`
    """
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    attn = None

    inner_states = [x]
    enc_dec_attn_weight_list = []

    # decoder layers
    for layer_id, layer in enumerate(self.layers):
        x, attn = layer(
            x,
            # the stock Fairseq code passes encoder_out['encoder_out'] here;
            # instead, layer i attends over encoder layer i's projected state
            final_encoder_inner_states[layer_id],
            encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
            incremental_state,
            self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
        )
        inner_states.append(x)
        enc_dec_attn_weight_list.append(attn)

    if self.normalize:
        x = self.layer_norm(x)
    # print("decoder-out:{}".format(x))
    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

    return x, {'attn': attn, 'inner_states': inner_states}
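
Correspondingly, here is a toy decoder that reads the per-layer states from encoder_out instead of a global list. It continues the ToyEncoder sketch above, with nn.MultiheadAttention standing in for a full decoder layer; all names are again my own assumptions, not Fairseq code.

class ToyDecoder(nn.Module):
    """Toy decoder whose layer i attends over encoder layer i's state."""

    def __init__(self, vocab=100, embed_dim=8, num_layers=3):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, embed_dim)
        self.layers = nn.ModuleList(
            nn.MultiheadAttention(embed_dim, num_heads=2) for _ in range(num_layers)
        )

    def forward(self, prev_output_tokens, encoder_out):
        x = self.embed_tokens(prev_output_tokens).transpose(0, 1)  # T x B x C
        for layer_id, layer in enumerate(self.layers):
            enc_state = encoder_out['inner_states'][layer_id]      # matching encoder layer
            x, _ = layer(query=x, key=enc_state, value=enc_state,
                         key_padding_mask=encoder_out['encoder_padding_mask'])
        return x.transpose(0, 1)                                   # B x T x C


src = torch.randint(1, 100, (2, 5))   # batch of 2, source length 5, no padding
tgt = torch.randint(1, 100, (2, 4))   # batch of 2, target length 4
encoder, decoder = ToyEncoder(), ToyDecoder()
out = decoder(tgt, encoder(src))
print(out.shape)                      # torch.Size([2, 4, 8])

Passing the states through encoder_out should also let them take part in the reordering Fairseq performs during beam search (reorder_encoder_out), which module-level globals would bypass.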

But I run into the following problems (see the attached error screenshots):
