Пользовательская функция сбора данных BoTorch - CUDA не хватает памяти
Мы используем настраиваемую функцию сбора данных как часть гауссовского процесса с использованием BoTorch и GpyTorch. В рамках метода forward у нас есть следующий код:
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
# Fix features
X = fix_features(X, self.fixed_features)
fixed_averages = dict(zip(self.averaging_features, [None for _ in self.averaging_features]))
X = fix_features(X, fixed_averages)
xndim = X.ndim
X = X.unsqueeze(0)
if self.averaging_features is not None:
X = X.repeat(self.num_samples, *tuple(torch.ones(xndim).int()))
X[..., self.averaging_features] = self.samples.unsqueeze(-2).unsqueeze(-2).to(X)
# Calculate average entropy
posterior = self.model.posterior(X)
...
Проблема в том, что когда мы используем нашу настраиваемую функцию сбора данных, мы довольно быстро получаем ошибку CUDA OOM. В целом мы получаем относительно высокие всплески памяти графического процессора из-за кода, когда память не освобождается. Мы предполагаем, что это как-то связано со строками 12 и 13 (X = X.repeat...). Вы хоть представляете, как добиться того же результата с помощью более эффективного использования памяти (GPU)? Кроме того, self.samples находится на процессоре, а X - от графического процессора. Как лучше всего с этим справиться? В настоящее время у нас есть.to(x) в конце строки 13. Спасибо!