Несколько tf.GradientTape на одном уровне дают градиент NaN

Я выполняю этапы обучения метаобучению, в то время как мне нужно выполнить два отдельных градиентных спуска, используя tf.GradientTapeдля support_set и query_set на одном уровне. но оказывается, что градиент, вычисленный для цикла набора запросов, дает [NaN, NaN,...]в print(gradient)в то время как это было хорошо внутри цикла поддержки. Я понятия не имею, почему это происходит, я попытался добавить persistent = True

Мой код выглядит следующим образом

                  for img, label in support_set:
              with tf.GradientTape() as train_tape:
                preds = model(img,training= True) 
                train_loss = keras.losses.sparse_categorical_crossentropy(label, preds)
              gradients = train_tape.gradient(train_loss, model.trainable_variables)
              inner_optimizer.apply_gradients(zip(gradients, model_copy.trainable_weights))

            model.load_weights("prev_wieghts.h5")

            for img,label in query_set:
              with tf.GradientTape() as test_tape:
                preds = model_copy(img) 
                test_loss = keras.losses.sparse_categorical_crossentropy(label, preds)
            
              gradients = test_tape.gradient(test_loss, model.trainable_variables)
              print(gradients) # It gives [NaN, NaN]
              outter_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
              prev_weights = model.save_weights("prev_wieghts.h5")

0 ответов

Другие вопросы по тегам