Потеря метаобучения не уменьшается

Когда я использую набор данных CIFAR-FS для обучения простой модели метаобучения, код работает, но потери не уменьшаются, а модель не сходится. Как мне заставить код работать успешно?

Это мои коды:

      from tensorflow.keras.models import Model
from tensorflow.keras import losses,layers
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten,Input
import numpy as np
np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
tf.enable_eager_execution()

class MAML:
def __init__(self,input_shape,num_classes):

    self.input_shape = input_shape
    self.num_classes = num_classes
    self.meta_model = self.get_model()


def get_model(self):

    inputs = Input(shape=(self.input_shape))
    out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(inputs)
    out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
    out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
    out = layers.Conv2D(64, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
    out = layers.Conv2D(64, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
    out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
    out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
    out = layers.Conv2D(32, kernel_size=[3, 3], padding="SAME", activation=tf.nn.relu)(out)
    out = layers.MaxPooling2D(pool_size=[2, 2], strides=2, padding="SAME")(out)
    out = Flatten()(out)
    out = Dense(32, activation="relu")(out)
    out = Dense(self.num_classes, activation='softmax')(out)
    model = Model(inputs=inputs,outputs=out)
    return model



def train_on_batch(self,data,inner_step,inner_optimizer,outer_optimizer=None):


    mean_loss = []
    mean_acc = []
    meta_support_image, meta_support_label, meta_query_image, meta_query_label = next(data)
    for support_image_one, support_label_one,query_image_one, query_label_one in zip(meta_support_image, meta_support_label, meta_query_image, meta_query_label):
        meta_weight = self.meta_model.get_weights()
        for _ in range(inner_step):
            with tf.GradientTape() as tape_sup:
                logits = self.meta_model(support_image_one,training=True)
                loss = losses.sparse_categorical_crossentropy(support_label_one, logits,from_logits=True)
                loss = tf.reduce_mean(loss)
            grads_sup = tape_sup.gradient(loss, self.meta_model.trainable_variables)
            inner_optimizer.apply_gradients(zip(grads_sup, self.meta_model.trainable_variables))

        with tf.GradientTape() as tape_que:
            logits = self.meta_model(query_image_one,training=True)
            loss = losses.sparse_categorical_crossentropy(query_label_one, logits,from_logits=True)
            loss = tf.reduce_mean(loss)
            logits = tf.argmax(logits,axis=-1)
            acc = tf.equal(logits,query_label_one)
            acc = tf.reduce_mean(tf.cast(acc,tf.float32))
            mean_acc.append(acc)
            mean_loss.append(loss)
        self.meta_model.set_weights(meta_weight)
        if outer_optimizer:
            grads_que = tape_que.gradient(loss, self.meta_model.trainable_variables)
            outer_optimizer.apply_gradients(zip(grads_que,self.meta_model.trainable_variables))

    loss_m = tf.reduce_mean(mean_loss)
    acc_m = tf.reduce_mean(mean_acc)

    return loss_m,acc_m,self.meta_model.get_weights() 

класс MAMLDataloader:

      def __init__(self,data_path,num_tasks,n_way = 3,k_shot = 10,q_query = 1):

    self.dataset = [] 
    for cls_filename in os.listdir(data_path):
        x = os.path.join(data_path, cls_filename)
        self.dataset.append(x)
    self.num_tasks = num_tasks
    self.n_way = n_way
    self.k_shot = k_shot
    self.q_query = q_query


def get_one_task_data(self):

    img_dirs = random.sample(self.dataset,self.n_way)
    support_data = []
    query_data = []
    support_image = []
    support_label = []
    query_image = []
    query_label = []

    for label,images_dir in enumerate(img_dirs):
        images = []
        for image in os.listdir(images_dir):
            image = os.path.join(images_dir,image)
            images.append(image)
        images = random.sample(images,self.k_shot+self.q_query)


        #read support set
        for img_path in images[:self.k_shot]:
            image = cv.imread(img_path,cv.IMREAD_COLOR)
            image = image.astype("float32")
            support_data.append((image,label))

        #read query set
        for img_path in images[self.k_shot:]:
            image = cv.imread(img_path,cv.IMREAD_COLOR)
            image = image.astype("float32")
            query_data.append((image, label))

    #shuffle support set
    random.shuffle(support_data)
    for data in support_data:
        support_image.append(data[0])
        support_label.append(data[1])

    #shuffle query set
    random.shuffle(query_data)
    for data in query_data:
        query_image.append(data[0])
        query_label.append(data[1])

    return  np.array(support_image),  np.array(support_label),  np.array(query_image),  np.array(query_label)



def get_one_epoch(self):

    while True:
        batch_support_image = []
        batch_support_label = []
        batch_query_image = []
        batch_query_label = []

        for _ in range(self.num_tasks):
            support_image, support_label, query_image, query_label = self.get_one_task_data()
            batch_support_image.append(support_image)
            batch_support_label.append(support_label)
            batch_query_image.append(query_image)
            batch_query_label.append(query_label)

        yield np.array(batch_support_image), np.array(batch_support_label),np.array(batch_query_image), np.array(batch_query_label)


from tensorflow.keras import optimizers
import tensorflow as tf
import numpy as np
from data_loader import MAMLDataloader
from maml import MAML
np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
tf.enable_eager_execution()

train_data_dir = "./CIFAR-FS/cifar100/data/"


inner_lr = 0.001
outer_lr = 0.005
imput_shape = (32,32,3)
n_way = 3
num_tasks = 15
epochs = 100
inner_step = 3
train_data = MAMLDataloader(train_data_dir,num_tasks,n_way,k_shot = 5,q_query = 5)

inner_optimizer = optimizers.Adam(inner_lr)
outer_optimizer = optimizers.Adam(outer_lr)

maml = MAML(input_shape=imput_shape,num_classes=n_way)


if __name__ == '__main__':

for e in range(epochs):
    print('\nEpoch {}/{}'.format(e + 1, epochs))
    loss,acc,weigth = maml.train_on_batch(train_data.get_one_epoch(),inner_step=inner_step,inner_optimizer=inner_optimizer,outer_optimizer=outer_optimizer)
    loss = loss.numpy()
    acc = acc.numpy()
    print("query set loss:{},acc;{}".format(loss,acc))

maml.meta_model.save_weights("maml_3_way.h5")

Потери были такими:

Потеря набора запросов эпохи 16/100: 1,2181113958358765, акк; 0,3333333432674408

Потеря набора запросов эпохи 17/100: 1,2181113958358765, акк; 0,3333333432674408

Потеря набора запросов эпохи 18/100: 1,2181113958358765, акк; 0,3333333432674408

Моя версия Tensorflow 1.14.0.

0 ответов

Другие вопросы по тегам