Будет ли pytorch корректировать порядок кода в соответствии с реальной ситуацией при обучении кода?
Недавно я столкнулся с проблемой при запуске исходного кода (здесь) по метаобучению: я хочу увидеть форму ввода, прежде чем передавать его в модель:
print(eps, episode_x.shape, train_input.shape)
Ниже будет исходный авторский код, который используется для вывода проигрыша этого цикла:
logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')
Весь код такой:
for eps, (episode_x, episode_y) in enumerate(train_loader):
# episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
# episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
# **this is the code i added**
print(eps, episode_x.shape, train_input.shape)
train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]
# Train learner with metalearner
learner_w_grad.reset_batch_stats()
learner_wo_grad.reset_batch_stats()
learner_w_grad.train()
learner_wo_grad.train()
cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)
# Train meta-learner with validation loss
learner_wo_grad.transfer_params(learner_w_grad, cI)
output = learner_wo_grad(test_input)
loss = learner_wo_grad.criterion(output, test_target)
acc = accuracy(output, test_target)
optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
optim.step()
# **The code used by the original author to output the loss of this loop**
logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')
# Meta-validation
if eps % args.val_freq == 0 and eps != 0:
save_ckpt(eps, metalearner, optim, args.save)
acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
if acc > best_acc:
best_acc = acc
logger.loginfo("* Best accuracy so far *\n")
logger.loginfo("Done")
Таким образом, само собой разумеется, что код должен сначала напечатать форму ввода, а затем вывести потерю этого цикла, но мой вывод во время выполнения кода таков:12
Это меня очень смущает, результат кода выглядит так, будто он выполняет полные 50 циклов, прежде чем напечатать то, что я хочу вывести, и сначала печатает потерю, а затем форму. Я посмотрел на код, и подробное изучение показывает, что такие поведения, как многопоточность, не прописаны в коде. Почему возникает такой результат?