Сверточный автоэнкодер обучается только на 1 канале
Мои данные имеют форму (100, 2, 2), и мой код учится на канале [:,:, 0], но не на канале [:,:, 1]
Соответствующая часть моего кода
Настроить
self.encoder_input = tf.placeholder(tf.float32, input_shape, name='x')
self.regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)
Кодер:
with tf.variable_scope("encoder"):
conv1 = tf.layers.conv2d(self.encoder_input, filters=32, kernel_size=(2, 2),
activation=tf.nn.relu, padding='same', kernel_regularizer=self.regularizer)
mp1 = tf.layers.max_pooling2d(conv1, pool_size=(4, 1), strides=(4, 1))
conv2 = tf.layers.conv2d(mp1, filters=64, kernel_size=(2, 2),
activation=None, padding='same', kernel_regularizer=self.regularizer)
return conv2
где conv2 затем подается в декодер:
def _construct_decoder(self, encoded):
with tf.variable_scope("decoder"):
upsample1 = tf.image.resize_images(encoded, size=(50, 2), method=tf.image.ResizeMethod.BILINEAR)
conv4 = tf.layers.conv2d(inputs=upsample1, filters=32, kernel_size=(2, 2), padding='same',
activation=tf.nn.relu, kernel_regularizer=self.regularizer)
upsample2 = tf.image.resize_images(conv4, size=(100, 2), method=tf.image.ResizeMethod.BILINEAR)
conv5 = tf.layers.conv2d(inputs=upsample2, filters=2, kernel_size=(2, 2), padding='same',
activation=None, kernel_regularizer=self.regularizer)
self.decoder = conv5
Мои потери следующие:
base_loss = tf.losses.mean_squared_error(labels=self.encoder_input, predictions=self.decoder)
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
loss = tf.add_n([base_loss] + reg_losses, name="loss")
cost = tf.reduce_mean(loss)
tf.summary.scalar('cost', cost)
optimizer = tf.train.AdamOptimizer(self.lr)
grads = optimizer.compute_gradients(cost)
# Update the weights wrt to the gradient
optimizer = optimizer.apply_gradients(grads)
# Save the grads with tf.summary.histogram
for index, grad in enumerate(grads):
tf.summary.histogram("{}-grad".format(grads[index][1].name), grads[index])
Я знаю, что это не обучение на втором канале, потому что я строю max, min, s.dev и так далее для разницы между фактическим и прогнозируемым для каждого канала. Я не совсем уверен, почему он учится на первом, а не на втором - у кого-нибудь есть идеи?