Будет ли pytorch корректировать порядок кода в соответствии с реальной ситуацией при обучении кода?

Недавно я столкнулся с проблемой при запуске исходного кода (здесь) по метаобучению: я хочу увидеть форму ввода, прежде чем передавать его в модель:

      print(eps, episode_x.shape, train_input.shape)

Ниже будет исходный авторский код, который используется для вывода проигрыша этого цикла:

      logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')

Весь код такой:

      for eps, (episode_x, episode_y) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
        
        # **this is the code i added**
        print(eps, episode_x.shape, train_input.shape)

        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
        
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        # **The code used by the original author to output the loss of this loop**
        logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')

        # Meta-validation
        if eps % args.val_freq == 0 and eps != 0:
            save_ckpt(eps, metalearner, optim, args.save)
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
            if acc > best_acc:
                best_acc = acc
                logger.loginfo("* Best accuracy so far *\n")

    logger.loginfo("Done")

Таким образом, само собой разумеется, что код должен сначала напечатать форму ввода, а затем вывести потерю этого цикла, но мой вывод во время выполнения кода таков:12

Это меня очень смущает, результат кода выглядит так, будто он выполняет полные 50 циклов, прежде чем напечатать то, что я хочу вывести, и сначала печатает потерю, а затем форму. Я посмотрел на код, и подробное изучение показывает, что такие поведения, как многопоточность, не прописаны в коде. Почему возникает такой результат?

0 ответов

Другие вопросы по тегам