Несоответствие форм, 2D-ввод и 2D-метки

Я хочу создать нейронную сеть, которая, легко говоря, создает изображение из изображения (оттенки серого). Я успешно создал набор данных из 3200 примеров изображений ввода и вывода (меток). (Я знаю, что набор данных должен быть больше, но сейчас это не проблема)

Вход [Xin] имеет размер (3200, 50, 30), поскольку он составляет 50*30 пикселей. Выход [yout] имеет размер (3200, 30, 20), поскольку он составляет 30*20 пикселей.

Я хочу опробовать полностью подключенную сеть (позже CNN). Построенная из полностью подключенной модели выглядит так:

# 5 Create Model
model = tf.keras.models.Sequential()                                
model.add(tf.keras.layers.Flatten())                                
model.add(tf.keras.layers.Dense(256, activation=tf.nn.relu))        
model.add(tf.keras.layers.Dense(30*20, activation=tf.nn.relu))    


#compile the model
model.compile(optimizer='adam',                                    
              loss='sparse_categorical_crossentropy',               
              metrics=['accuracy'])                                 

# 6 Train the model
model.fit(Xin, yout, epochs=1)                                      #train the model

После этого я получаю следующую ошибку:

ValueError: Несоответствие формы: форма меток (полученная (19200,)) должна соответствовать форме логитов, за исключением последнего измерения (получено (32, 600)).

Я уже пытался тебя сгладить:

youtflat = yout.transpose(1,0,2).reshape(-1,yout.shape[1]*yout.shape[2])

но это привело к той же ошибке

1 ответ

Решение

Похоже, вы полностью сглаживаете свои метки (yout), то есть теряете размер партии. Если ваш оригинальный yout имеет форму(3200, 30, 20) вы должны изменить его, чтобы он имел форму (3200, 30*20) что равно (3200, 600):

yout = numpy.reshape((3200, 600))

Тогда это должно работать

ПРИМЕЧАНИЕ. Предлагаемое исправление, однако, устраняет только ошибку. Однако я вижу много проблем с вашим методом. Для задачи, которую вы пытаетесь выполнить (получение изображения в качестве вывода), вы не можете использоватьsparse_categorical_crossentropy как потеря и accuracyкак метрики. Вместо этого вы должны использовать "mse" или "mae".

Другие вопросы по тегам