Передача initial_state в двунаправленный слой RNN в Keras
Я пытаюсь реализовать сеть типа кодер-декодер в Keras, с двунаправленными GRU.
Кажется, работает следующий код
src_input = Input(shape=(5,))
ref_input = Input(shape=(5,))
src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)
encoder = Bidirectional(
GRU(2, return_sequences=True, return_state=True)
)(src_embedding)
decoder = GRU(2, return_sequences=True)(ref_embedding, initial_state=encoder[1])
Но когда я меняю декодер, чтобы использовать Bidirectional
обертка, перестает показывать encoder
а также src_input
слои в model.summary()
, Новый декодер выглядит так:
decoder = Bidirectional(
GRU(2, return_sequences=True)
)(ref_embedding, initial_state=encoder[1:])
Выход из model.summary()
с двунаправленным декодером.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) (None, 5) 0
_________________________________________________________________
embedding_2 (Embedding) (None, 5, 300) 6610500
_________________________________________________________________
bidirectional_2 (Bidirection (None, 5, 4) 3636
=================================================================
Total params: 6,614,136
Trainable params: 6,614,136
Non-trainable params: 0
_________________________________________________________________
Вопрос: я что-то упускаю, когда прохожу initial_state
в Bidirectional
декодер? Как я могу это исправить? Есть ли другой способ сделать эту работу?
1 ответ
Это ошибка. RNN
слой реализует __call__
так что тензоры в initial_state
могут быть собраны в экземпляр модели. Тем не менее Bidirectional
обертка не реализовала это. Так что топологическая информация о initial_state
тензор отсутствует и случаются странные ошибки.
Я не знал об этом, когда я осуществлял initial_state
за Bidirectional
, Это надо исправить сейчас, после этого пиара. Вы можете установить последнюю ветку master на GitHub, чтобы исправить это.