Ошибка при преобразовании массива изображений во время обучения CNN с использованием numpy
Я пытаюсь обучить модель на некотором наборе изображений. Но во время обучения я получаю следующую ошибку:
ValueError: could not broadcast input array from shape (64,64,3) into shape (64,64)
Я изменил размеры всех изображений, чтобы сформировать (64,64,3), используя tflearn.data_utils image_preloader
функция. Я не понимаю, что я делаю здесь неправильно
Вот мой код:
IMAGE_SIZE = 64
NUM_CHANNEL = 3
#Importing data
X_train, Y_train = image_preloader(TRAIN_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL),mode='file', categorical_labels=True,normalize=True)
X_test, Y_test = image_preloader(TEST_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL),mode='file', categorical_labels=True,normalize=True)
X = tf.placeholder(tf.float32,shape=[None, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL], name='input_image')
#input class
Y_ = tf.placeholder(tf.float32,shape=[None, NUM_CLASS], name='input_class')
Это основной цикл обучения:
previous_batch = 0
start_time = time.time()
for i in range(epoch):
#batch wise training
if previous_batch >= len(X_train) : #total --> total number of training images
previous_batch = 0
current_batch = previous_batch + batch_size
if current_batch > len(X_train) :
current_batch = len(X_train)
print("Prev =", previous_batch, "Curr =", current_batch)
x_input = X_train[previous_batch : current_batch]
print("x_input length =", len(x_input))
x_images = np.reshape(np.array(x_input), [batch_size, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])
y_input = Y_train[previous_batch : current_batch]
y_label = np.reshape(np.array(y_input), [batch_size, NUM_CLASS])
previous_batch = previous_batch + batch_size
_, loss = sess.run([train_step, cross_entropy], feed_dict = {X: x_images, Y_: y_label})
if i % 500 == 0:
n = 50 #number of test samples
x_test_images = np.reshape(np.array(X_test[0 : n]), [n, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])
y_test_labels = np.reshape(np.array(Y_test[0 : n]), [n, NUM_CLASS])
Accuracy = sess.run(accuracy, feed_dict = {X: x_test_images, Y_: y_test_labels})
print("Iteration no : {}, Accuracy : {}, Loss : {}" .format(i, Accuracy, loss))
saver.save(sess, save_path, global_step = i)
elif i % 100 == 0:
print("Iteration no : {} Loss : {}" .format(i, loss))
saver.save(sess, save_path)
print("Time required = %f sec" % (time.time() - start_time))
Я получаю вышеуказанную ошибку в строке с кодом:
x_test_images = np.reshape(np.array(X_test[0 : n]), [n, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])