The diffusers model has no grad_fn, how do I train it?

I built a diffusion model using parts from the diffusers package.

import torch
import torch.nn as nn
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer


class DiffusionModule(nn.Module):
    def __init__(self, num_inference_steps=10, guidance_scale=7.5):
        super().__init__()
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.generator = torch.manual_seed(42)
        self.out_H = 512
        self.out_W = 512

        self.vaem = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="vae"
        )
        self.tokm = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.tenm = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        self.untm = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet"
        )
        self.scheduler = LMSDiscreteScheduler.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
        )

        self.default_unc_inp = None

    def forward(self, x):
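        # x: conditioning embeddings produced by brain_module; they must match
        # the shape of the CLIP text embeddings so torch.cat([unc_emb, x]) works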
        if self.default_unc_inp is None:
            unc_inp = self.tokm(
                [""] * x.shape[0],
                padding="max_length",
                max_length=self.tokm.model_max_length,
                return_tensors="pt",
            )
            self.default_unc_inp = unc_inp
        else:
            unc_inp = self.default_unc_inp
        unc_emb = self.tenm(unc_inp.input_ids.to(x.device))[0]
        con_emb = torch.cat([unc_emb, x])

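        # Start from random latents at 1/8 of the target resolution
        # (the VAE decoder upsamples by a factor of 8)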
        latents = torch.randn(
            (x.shape[0], self.untm.in_channels, self.out_H // 8, self.out_W // 8),
            generator=self.generator,
        ).to(x.device)

        self.scheduler.set_timesteps(self.num_inference_steps)
        latents = latents * self.scheduler.init_noise_sigma

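        # Denoising loop with classifier-free guidance: each step runs the UNet
        # on an unconditional and a conditional copy of the latents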
        for t in self.scheduler.timesteps:
            latent_inp = torch.cat([latents] * 2).to(x.device)
            latent_inp = self.scheduler.scale_model_input(latent_inp, t)
            with torch.no_grad():
                noise_pred = self.untm(
                    latent_inp, t, encoder_hidden_states=con_emb.to(x.device)
                ).sample
            noise_pred_unc, noise_pred_x = noise_pred.chunk(2)
            noise_pred = noise_pred_unc + self.guidance_scale * (
                noise_pred_x - noise_pred_unc
            )
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
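        # Rescale by the VAE scaling factor (0.18215) before decoding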
        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vaem.decode(latents).sample
        return imgs

I would like to train brain_module, which embeds the EEG signals into the required shape, but I get an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-a304a6750226> in <cell line: 5>()
     15         loss = criterion(generated_images, images)
     16 
---> 17         loss.backward()
     18         optimizer.step()
     19 

1 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

when I train using this code:

import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(brain_module.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop
for epoch in range(1):
    for eegs, images in data_loader:
        eegs = eegs.float().to(torch_device)
        images = images.float().to(torch_device)
        
        optimizer.zero_grad()
    
        embeddings = brain_module(eegs)
        generated_images = diffusion_module(embeddings)

        loss = criterion(generated_images, images)

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{10}, Loss: {loss.item()}")```

Can someone please tell me how to make the training loop work in this case?


I was expecting the training to work, but I don't have much experience with PyTorch autograd.
