keras Двунаправленный слой с использованием 4-х мерных данных
Я разрабатываю модель Keras для классификации на основе данных статьи.
У меня есть данные с 4 измерения следующим образом
[batch, article_num, word_num, word embedding size]
и я хочу, чтобы каждый (word_num, вложение слова) данные в керас двунаправленный слой
чтобы получить результат с 3-мя измерениями следующим образом.
[batch, article_num, bidirectional layer output size]
когда я попытался передать 4 измерения данных для тестирования, как это
inp = Input(shape=(article_num, word_num, ))
# dims = [batch, article_num, word_num]
x = Reshape((article_num * word_num, ), input_shape = (article_num, word_num))(inp)
# dims = [batch, article_num * word_num]
x = Embedding(word_num, word_embedding_size, input_length = article_num * word_num)(x)
# dims = [batch, article_num * word_num, word_embedding_size]
x = Reshape((article_num , word_num, word_embedding_size),
input_shape = (article_num * word_num, word_embedding_size))(x)
# dims = [batch, article_num, word_num, word_embedding_size]
x = Bidirectional(CuDNNLSTM(50, return_sequences = True),
input_shape=(article_num , word_num, word_embedding_size))(x)
и я получил ошибку
ValueError: Input 0 is incompatible with layer bidirectional_12: expected ndim=3, found ndim=4
Как я могу достичь этого?
1 ответ
Решение
Если вы не хотите, чтобы это коснулось article_num
измерение, вы можете попробовать использовать TimeDistributed
обертка. Но я не уверен, что это будет совместимо с двунаправленными и другими вещами.
inp = Input(shape=(article_num, word_num))
x = TimeDistributed(Embedding(word_num, word_embedding_size)(x))
#option 1
#x1 shape : (batch, article_num, word_num, 50)
x1 = TimeDistributed(Bidirectional(CuDNNLSTM(50, return_sequences = True)))(x)
#option 2
#x2 shape : (batch, article_num, 50)
x2 = TimeDistributed(Bidirectional(CuDNNLSTM(50)))(x)
подсказки:
- Не использовать
input_shape
везде, вам нужно только наInput
тензор. - Вам, вероятно, не нужно ничего менять, если вы также используете
TimeDistributed
в встраивании. - Если не хочешь
word_num
в последнем измерении используйтеreturn_sequences=False
,