Meta-learning loss does not decrease
When I train a simple meta-learning model (MAML) on the CIFAR-FS dataset, the code runs, but the loss does not decrease and the model does not converge. How can I get this code to train successfully?
Here is my code (three files: maml.py, data_loader.py, and the training script):
# maml.py
from tensorflow.keras.models import Model
from tensorflow.keras import losses, layers
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Input
import numpy as np

np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
tf.enable_eager_execution()


class MAML:
    def __init__(self, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.meta_model = self.get_model()

    def get_model(self):
        # simple conv backbone: three conv-conv-pool stages plus a small classifier head
        inputs = Input(shape=self.input_shape)
        out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(inputs)
        out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
        out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
        out = layers.Conv2D(64, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
        out = layers.Conv2D(64, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
        out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
        out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
        out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
        out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
        out = Flatten()(out)
        out = Dense(32, activation="relu")(out)
        out = Dense(self.num_classes, activation='softmax')(out)
        return Model(inputs=inputs, outputs=out)

    def train_on_batch(self, data, inner_step, inner_optimizer, outer_optimizer=None):
        mean_loss = []
        mean_acc = []
        meta_support_image, meta_support_label, meta_query_image, meta_query_label = next(data)
        for support_image_one, support_label_one, query_image_one, query_label_one in zip(
                meta_support_image, meta_support_label, meta_query_image, meta_query_label):
            # remember the meta weights so they can be restored after inner adaptation
            meta_weight = self.meta_model.get_weights()
            # inner loop: adapt the model to the support set of this task
            for _ in range(inner_step):
                with tf.GradientTape() as tape_sup:
                    logits = self.meta_model(support_image_one, training=True)
                    loss = losses.sparse_categorical_crossentropy(support_label_one, logits, from_logits=True)
                    loss = tf.reduce_mean(loss)
                grads_sup = tape_sup.gradient(loss, self.meta_model.trainable_variables)
                inner_optimizer.apply_gradients(zip(grads_sup, self.meta_model.trainable_variables))
            # evaluate the adapted weights on the query set
            with tf.GradientTape() as tape_que:
                logits = self.meta_model(query_image_one, training=True)
                loss = losses.sparse_categorical_crossentropy(query_label_one, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
                preds = tf.argmax(logits, axis=-1)
                acc = tf.equal(preds, query_label_one)
                acc = tf.reduce_mean(tf.cast(acc, tf.float32))
                mean_acc.append(acc)
                mean_loss.append(loss)
            # restore the meta weights, then apply the outer (meta) update with the query gradients
            self.meta_model.set_weights(meta_weight)
            if outer_optimizer:
                grads_que = tape_que.gradient(loss, self.meta_model.trainable_variables)
                outer_optimizer.apply_gradients(zip(grads_que, self.meta_model.trainable_variables))
        loss_m = tf.reduce_mean(mean_loss)
        acc_m = tf.reduce_mean(mean_acc)
        return loss_m, acc_m, self.meta_model.get_weights()
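
For reference, the backbone itself builds and runs in eager mode. This is only an illustrative shape check on random dummy data, not part of my training code; the (32, 32, 3) input shape and 3 classes match the values used in the training script below:

check = MAML(input_shape=(32, 32, 3), num_classes=3)
check.meta_model.summary()
# hypothetical dummy batch of 4 images, just to confirm the output shape
dummy = np.random.rand(4, 32, 32, 3).astype("float32")
print(check.meta_model(dummy, training=False).shape)  # expected: (4, 3)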
# data_loader.py
import os
import random
import cv2 as cv
import numpy as np


class MAMLDataloader:
    def __init__(self, data_path, num_tasks, n_way=3, k_shot=10, q_query=1):
        # each sub-directory of data_path holds the images of one class
        self.dataset = []
        for cls_filename in os.listdir(data_path):
            self.dataset.append(os.path.join(data_path, cls_filename))
        self.num_tasks = num_tasks
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query

    def get_one_task_data(self):
        # sample n_way classes, then k_shot support and q_query query images per class
        img_dirs = random.sample(self.dataset, self.n_way)
        support_data = []
        query_data = []
        support_image = []
        support_label = []
        query_image = []
        query_label = []
        for label, images_dir in enumerate(img_dirs):
            images = [os.path.join(images_dir, image) for image in os.listdir(images_dir)]
            images = random.sample(images, self.k_shot + self.q_query)
            # read support set
            for img_path in images[:self.k_shot]:
                image = cv.imread(img_path, cv.IMREAD_COLOR)
                image = image.astype("float32")
                support_data.append((image, label))
            # read query set
            for img_path in images[self.k_shot:]:
                image = cv.imread(img_path, cv.IMREAD_COLOR)
                image = image.astype("float32")
                query_data.append((image, label))
        # shuffle support set
        random.shuffle(support_data)
        for data in support_data:
            support_image.append(data[0])
            support_label.append(data[1])
        # shuffle query set
        random.shuffle(query_data)
        for data in query_data:
            query_image.append(data[0])
            query_label.append(data[1])
        return np.array(support_image), np.array(support_label), np.array(query_image), np.array(query_label)

    def get_one_epoch(self):
        # endless generator that yields num_tasks tasks at a time
        while True:
            batch_support_image = []
            batch_support_label = []
            batch_query_image = []
            batch_query_label = []
            for _ in range(self.num_tasks):
                support_image, support_label, query_image, query_label = self.get_one_task_data()
                batch_support_image.append(support_image)
                batch_support_label.append(support_label)
                batch_query_image.append(query_image)
                batch_query_label.append(query_label)
            yield (np.array(batch_support_image), np.array(batch_support_label),
                   np.array(batch_query_image), np.array(batch_query_label))
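
The loader assumes one sub-directory per class under data_path, i.e. ./CIFAR-FS/cifar100/data/<class_name>/<image files>. For illustration only (not part of my training code), a single 3-way, 5-shot/5-query task should come back with these shapes:

loader = MAMLDataloader("./CIFAR-FS/cifar100/data/", num_tasks=15, n_way=3, k_shot=5, q_query=5)
s_img, s_lbl, q_img, q_lbl = loader.get_one_task_data()
print(s_img.shape, s_lbl.shape)  # expected: (15, 32, 32, 3) (15,)   i.e. n_way * k_shot
print(q_img.shape, q_lbl.shape)  # expected: (15, 32, 32, 3) (15,)   i.e. n_way * q_query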
# training script (file name not shown; it imports the two modules above)
from tensorflow.keras import optimizers
import tensorflow as tf
import numpy as np
from data_loader import MAMLDataloader
from maml import MAML

np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
tf.enable_eager_execution()

train_data_dir = "./CIFAR-FS/cifar100/data/"
inner_lr = 0.001
outer_lr = 0.005
input_shape = (32, 32, 3)
n_way = 3
num_tasks = 15
epochs = 100
inner_step = 3

train_data = MAMLDataloader(train_data_dir, num_tasks, n_way, k_shot=5, q_query=5)
inner_optimizer = optimizers.Adam(inner_lr)
outer_optimizer = optimizers.Adam(outer_lr)
maml = MAML(input_shape=input_shape, num_classes=n_way)

if __name__ == '__main__':
    for e in range(epochs):
        print('\nEpoch {}/{}'.format(e + 1, epochs))
        loss, acc, weights = maml.train_on_batch(train_data.get_one_epoch(),
                                                 inner_step=inner_step,
                                                 inner_optimizer=inner_optimizer,
                                                 outer_optimizer=outer_optimizer)
        print("query set loss: {}, acc: {}".format(loss.numpy(), acc.numpy()))
        maml.meta_model.save_weights("maml_3_way.h5")
The loss looked like this; it stays exactly the same from epoch to epoch, and the accuracy is stuck at 1/3, i.e. chance level for 3-way classification:

Epoch 16/100
query set loss: 1.2181113958358765, acc: 0.3333333432674408
Epoch 17/100
query set loss: 1.2181113958358765, acc: 0.3333333432674408
Epoch 18/100
query set loss: 1.2181113958358765, acc: 0.3333333432674408
My TensorFlow version is 1.14.0.