Keras Val_acc хорош, но прогноз для тех же данных плох
Я использую Keras для классификации двух классов CNN. Во время тренировки мой val_acc выше 95 процентов. Но когда я прогнозирую результат для тех же данных проверки, акк составляет менее 60 процентов, это вообще возможно? Это мой код:
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.preprocessing import image
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1337) # for reproducibility
%matplotlib inline
img_width, img_height = 230,170
train_data_dir = 'data/Train'
validation_data_dir = 'data/Validation'
nb_train_samples = 13044
nb_validation_samples = 200
epochs =14
batch_size = 32
if K.image_data_format() == 'channels_first':
input_shape = (1, img_width, img_height)
else:
input_shape = (img_width, img_height, 1)
model = Sequential()
model.add(Convolution2D(32, (3, 3),data_format='channels_first' , input_shape=(1,230,170)))
convout1 = Activation('relu')
model.add(convout1)
convout2 = MaxPooling2D(pool_size=(2,2 ), strides= None , padding='valid', data_format='channels_first')
model.add(convout2)
model.add(Convolution2D(32, (3, 3),data_format='channels_first'))
convout3 = Activation('relu')
model.add(convout3)
model.add(MaxPooling2D(pool_size=(2, 2), data_format='channels_first'))
model.add(Convolution2D(64, (3, 3),data_format='channels_first'))
convout4 = Activation('relu')
model.add(convout4)
convout5 = MaxPooling2D(pool_size=(2, 2), data_format='channels_first')
model.add(convout5)
model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
train_datagen = ImageDataGenerator(rescale=1. / 255,
shear_range=0,
zoom_range=0.2,
horizontal_flip=False,
data_format='channels_first')
test_datagen = ImageDataGenerator(rescale=1. / 255,
data_format='channels_first')
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='binary',
color_mode= "grayscale",
shuffle=True
)
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='binary',
color_mode= "grayscale",
shuffle=True
)
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=epochs,
validation_data=validation_generator,
validation_steps=nb_validation_samples // batch_size,
shuffle=True
)
Эпоха 37/37
407/407 [==============] - 1775 с 4 с / шаг - потеря: 0,12 - соотв: 0,96 - val_loss: 0,02 - val_acc: 0,99
#Prediction:
test_data_dir='data/test'
validgen = ImageDataGenerator(horizontal_flip=False, data_format='channels_first')
test_gen = validgen.flow_from_directory(
test_data_dir,
target_size=(img_width, img_height),
batch_size=1,
class_mode='binary',
shuffle=False,
color_mode= "grayscale")
preds = model.predict_generator(test_gen)
В приведенном ниже выводе около 7 изображений относятся к классу 0. Я попробовал то же самое для всех 100 изображений данных проверки класса 0, и только 15 изображений были предсказаны как класс 0, а остальные были предсказаны как класс 1
Found 10 images belonging to 1 classes.
[[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 0.]
[ 0.]
[ 1.]]
1 ответ
Вы не масштабируете свои тестовые изображения на 1./255, как на тренировочных и проверочных изображениях. В идеале статистика ваших тестовых данных должна быть аналогична тренировочным данным.
Итак, я решил опубликовать ответ, который я опубликовал в Quora, но с основной частью, как было рекомендовано. У меня тоже была похожая проблема, как эта, и я надеюсь, что мой ответ может помочь кому-то еще. Я решил исследовать Интернет и наткнулся на этот ответ cjbayron.
Что помогло мне решить аналогичную проблему, так это то, что в моем коде для обучения модели было следующее:
import keras
import os
from keras import backend as K
import tensorflow as tf
import random as rn
import numpy as np
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(70)
rn.seed(70)
tf.set_random_seed(70)
/******* code for my model ******/
#very important here to save session after completing model.fit
model.fit_generator(train_batches, steps_per_epoch=4900, validation_data=valid_batches,validation_steps=1225, epochs=40, verbose=2, callbacks=callbacks_list)
saver = tf.train.Saver()
sess = keras.backend.get_session()
saver.save(sess, 'gdrive/My Drive/KerasCNN/model/keras_session/session.ckpt')
сохраненный сеанс также сгенерирует следующие файлы:
- / Keras_session/ контрольно-пропускной пункт
- /keras_session/session.ckpt.data-00000-of-00001
- /keras_session/session.ckpt.index
- /keras_session/session.ckpt.meta
Я также скачал все эти файлы с моего Google Диска и поместил их в локальный каталог. Вы можете заметить, что, похоже, нет файла с именем session.ckpt, но он используется в saver.restore(). Это нормально. Tensorflow вроде работает. Это не принесет ошибки.
Во время model.load_model()
Итак, в моем Pycharm я загрузил модель следующим образом:
model=load_model('C:\\Users\\Username\\PycharmProjects\\MyProject\\mymodel\\mymodel.h5')
saver = tf.train.Saver()
sess = keras.backend.get_session()
saver.restore(sess,'C:\\Users\\Username\\PycharmProjects\\MyProject\\mymodel\\keras_session\\session.ckpt')
/***** then predict the images as you wish ******/
pred = model.predict_classes(load_image(os.path.join(test_path, file)))
Важно разместить код восстановления, как показано, т.е. после загрузки модели. Как только я это сделал, я попытался предсказать те же изображения, которые я использовал для обучения и проверки, и на этот раз модель ошибочно предсказала около 2 изображений в классе. Теперь я был уверен, что с моей моделью все в порядке, и я пошел вперед, чтобы предсказать мои тестовые изображения, то есть изображения, которых он не видел раньше, и они работали очень хорошо.