Несколько tf.GradientTape на одном уровне дают градиент NaN
Я выполняю этапы обучения метаобучению, в то время как мне нужно выполнить два отдельных градиентных спуска, используя
tf.GradientTape
для support_set и query_set на одном уровне. но оказывается, что градиент, вычисленный для цикла набора запросов, дает
[NaN, NaN,...]
в
print(gradient)
в то время как это было хорошо внутри цикла поддержки. Я понятия не имею, почему это происходит, я попытался добавить
persistent = True
Мой код выглядит следующим образом
for img, label in support_set:
with tf.GradientTape() as train_tape:
preds = model(img,training= True)
train_loss = keras.losses.sparse_categorical_crossentropy(label, preds)
gradients = train_tape.gradient(train_loss, model.trainable_variables)
inner_optimizer.apply_gradients(zip(gradients, model_copy.trainable_weights))
model.load_weights("prev_wieghts.h5")
for img,label in query_set:
with tf.GradientTape() as test_tape:
preds = model_copy(img)
test_loss = keras.losses.sparse_categorical_crossentropy(label, preds)
gradients = test_tape.gradient(test_loss, model.trainable_variables)
print(gradients) # It gives [NaN, NaN]
outter_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
prev_weights = model.save_weights("prev_wieghts.h5")