keras Двунаправленный слой с использованием 4-х мерных данных

Я разрабатываю модель Keras для классификации на основе данных статьи.

У меня есть данные с 4 измерения следующим образом

[batch, article_num, word_num, word embedding size]

и я хочу, чтобы каждый (word_num, вложение слова) данные в керас двунаправленный слой

чтобы получить результат с 3-мя измерениями следующим образом.

[batch, article_num, bidirectional layer output size]

когда я попытался передать 4 измерения данных для тестирования, как это

inp = Input(shape=(article_num, word_num, ))
# dims = [batch, article_num, word_num]

x = Reshape((article_num * word_num, ), input_shape = (article_num, word_num))(inp)
# dims = [batch, article_num * word_num]

x = Embedding(word_num, word_embedding_size, input_length = article_num * word_num)(x)
# dims = [batch, article_num * word_num, word_embedding_size]

x = Reshape((article_num , word_num, word_embedding_size), 
             input_shape = (article_num * word_num, word_embedding_size))(x)
# dims = [batch, article_num, word_num, word_embedding_size]

x = Bidirectional(CuDNNLSTM(50, return_sequences = True), 
                  input_shape=(article_num , word_num, word_embedding_size))(x)

и я получил ошибку

ValueError: Input 0 is incompatible with layer bidirectional_12: expected ndim=3, found ndim=4

Как я могу достичь этого?

1 ответ

Решение

Если вы не хотите, чтобы это коснулось article_num измерение, вы можете попробовать использовать TimeDistributed обертка. Но я не уверен, что это будет совместимо с двунаправленными и другими вещами.

inp = Input(shape=(article_num, word_num))    

x = TimeDistributed(Embedding(word_num, word_embedding_size)(x))

#option 1
#x1 shape : (batch, article_num, word_num, 50)
x1 = TimeDistributed(Bidirectional(CuDNNLSTM(50, return_sequences = True)))(x)

#option 2
#x2 shape : (batch, article_num, 50)
x2 = TimeDistributed(Bidirectional(CuDNNLSTM(50)))(x)

подсказки:

  • Не использовать input_shape везде, вам нужно только на Input тензор.
  • Вам, вероятно, не нужно ничего менять, если вы также используете TimeDistributed в встраивании.
  • Если не хочешь word_num в последнем измерении используйте return_sequences=False,
Другие вопросы по тегам