Передача initial_state в двунаправленный слой RNN в Keras

Я пытаюсь реализовать сеть типа кодер-декодер в Keras, с двунаправленными GRU.

Кажется, работает следующий код

src_input = Input(shape=(5,))
ref_input = Input(shape=(5,))

src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)

encoder = Bidirectional(
                GRU(2, return_sequences=True, return_state=True)
        )(src_embedding)

decoder = GRU(2, return_sequences=True)(ref_embedding, initial_state=encoder[1])

Но когда я меняю декодер, чтобы использовать Bidirectional обертка, перестает показывать encoder а также src_input слои в model.summary(), Новый декодер выглядит так:

decoder = Bidirectional(
                GRU(2, return_sequences=True)
        )(ref_embedding, initial_state=encoder[1:])

Выход из model.summary() с двунаправленным декодером.

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 5, 300)            6610500   
_________________________________________________________________
bidirectional_2 (Bidirection (None, 5, 4)              3636      
=================================================================
Total params: 6,614,136
Trainable params: 6,614,136
Non-trainable params: 0
_________________________________________________________________

Вопрос: я что-то упускаю, когда прохожу initial_state в Bidirectional декодер? Как я могу это исправить? Есть ли другой способ сделать эту работу?

1 ответ

Решение

Это ошибка. RNN слой реализует __call__ так что тензоры в initial_state могут быть собраны в экземпляр модели. Тем не менее Bidirectional обертка не реализовала это. Так что топологическая информация о initial_state тензор отсутствует и случаются странные ошибки.

Я не знал об этом, когда я осуществлял initial_state за Bidirectional, Это надо исправить сейчас, после этого пиара. Вы можете установить последнюю ветку master на GitHub, чтобы исправить это.

Другие вопросы по тегам