Несовместимая форма ввода
У меня есть модель CNN-LSTM, где модель CNN принимает в качестве входных данных с формой (None, 301,4,1) и выводит данные с формой (None, 606) . Чтобы адаптировать вывод cnn к входу LSTM , я добавил слой TimeDistributed, где он вызывает модель CNN каждый размер окна =100, поэтому входная форма этого слоя =(None, 100,301,4,1) , а затем у нас есть несколько сложенных слоев LSTM .
Это архитектура модели CNN :
Это архитектура модели LSTM :
Код для этой архитектуры следующий:
input_layer1=Input(shape=(301,4,1))
...
merge_layer=Concatenate(axis=1)([global_max_pooling, lambda_14])
cnn_model = Model(inputs= input_layer1, outputs=merge_layer) cnn_model.compile(optimizer=RMSprop(),loss="mean_squared_error",metrics=['mse', 'mae'])
input_lstm = Input(shape=(100,301,4,1))
cnn_output = TimeDistributed(cnn_model)(input_lstm)
...
output_layer=Dense(1,activation="linear")(lstm3)
cnn_lstm_model = Model(inputs= input_lstm, outputs=output_layer)
cnn_lstm_model.compile(optimizer=RMSprop(),loss="mean_squared_error",metrics=['mse', 'mae'])
Потом сохранил только модель cnn_lstm_model.
Для обучения это мой код:
batchsize=100
epoch=20
cnn_lstm_model.fit(train_data_force_temp_X,data_Y,
batch_size=batchsize,
epochs=epoch,
verbose=1,
shuffle=True,
validation_data=(test_data_force_temp_X,test_Y),
callbacks=[TensorBoard(log_dir="./CNN_LSTM")])
Где train_data_force_temp_X.shape =(1960, 301, 4, 1), PS: 1960 - количество образцов.
Но у меня такая проблема:
ValueError: вход 0 несовместим со слоем model_1: ожидаемая форма =(Нет, 100, 301, 4, 1), найденная форма =(Нет, 301, 4, 1)
Я понял, что передал неправильную форму в cnn_lstm_model, но я подумал, что он сначала передаст данные в модель cnn, которая имеет форму =(None, 301, 4, 1), а затем для каждых 100 выходов CNN он будет вызывать время распределенный слой и продолжаю процесс. Кажется, я неправильно понял процесс.
Итак, мой вопрос:
мне нужно сначала запустить данные в модель cnn, сделать прогноз, а затем использовать эти выходные данные в качестве входных данных для модели cnn_lstm?
Как исправить тренировочный процесс?
Заранее благодарю за помощь.