Errors when using PyTorch Geometric with TorchDyn
I am trying to convert a GCN written with PyTorch Geometric into a graph neural ODE (GDE) using TorchDyn. The GCN works correctly on its own, and I am using the following code to convert it into a GDE:
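import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchdyn.core import NeuralODE
from torchdyn.nn import DataControl
# MyGCN, Learner and logger are defined elsewhere in the project
t_span = torch.linspace(0, 1, 2)
model_sub = nn.Sequential(DataControl(), MyGCN(...))
ode = NeuralODE(model_sub, sensitivity='adjoint', solver='rk4', solver_adjoint='dopri5', atol_adjoint=1e-4, rtol_adjoint=1e-4).to('cuda:0')
model = Learner(t_span, ode)
trainer = pl.Trainer(gpus=1, precision=32, limit_train_batches=0.5,
                     auto_lr_find=True, logger=logger)
trainer.fit(model, train_loader, val_loader)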
which raises the following error:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/model/main.py", line 30, in <module>
trainer.fit(model, train_loader, val_loader)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
self._dispatch()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
return self._run_train()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
self._run_sanity_check(self.lightning_module)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
self._evaluation_loop.run()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 123, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 215, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
return self.model.validation_step(*args, **kwargs)
File "/home/ubuntu/Desktop/model/model_ode.py", line 102, in validation_step
return self.model_step(val_batch, batch_idx, 'val')
File "/home/ubuntu/Desktop/model/model_ode.py", line 54, in model_step
t_eval, y_hat = self(batch)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/Desktop/model/model_ode.py", line 28, in forward
return self.model(x)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 92, in forward
x, t_span = self._prep_integration(x, t_span)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 88, in _prep_integration
module.u = x[:, excess_dims:].detach()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 154, in __getitem__
return self.index_select(idx)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in index_select
return [self.get_example(i) for i in idx]
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in <listcomp>
return [self.get_example(i) for i in idx]
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 96, in get_example
data = separate(
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 40, in separate
data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 85, in _separate
start, end = slices[idx], slices[idx + 1]
TypeError: unsupported operand type(s) for +: 'slice' and 'int'
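The last frames of the traceback suggest what happens: torchdyn's _prep_integration slices the state as if it were a tensor (x[:, excess_dims:]), but here x is a torch_geometric Batch. Batch.__getitem__ receives the tuple (slice(None), slice(excess_dims, None)), iterates over it, and ends up calling get_example with a slice object, where slices[idx + 1] fails. The pure-Python stand-in below reproduces only that mechanism; FakeBatch is a hypothetical class that mimics the frames shown above, not the real PyG implementation.

```python
class FakeBatch:
    """Hypothetical stand-in for torch_geometric's Batch (NOT the real class)."""

    def __init__(self, slices):
        self.slices = slices  # cumulative node offsets, one boundary per graph

    def __getitem__(self, idx):
        # x[:, k:] arrives here as the tuple (slice(None), slice(k, None));
        # a tuple is a sequence, so each element is treated as a graph index
        return [self.get_example(i) for i in idx]

    def get_example(self, i):
        # mirrors `start, end = slices[idx], slices[idx + 1]` from separate.py;
        # with i being a slice, `i + 1` raises TypeError
        return self.slices[i], self.slices[i + 1]


def reproduce(excess_dims=0):
    x = FakeBatch([0, 3, 5])
    try:
        x[:, excess_dims:]  # the same indexing torchdyn applies to the state
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce())
```

Running it prints the same message as the last line of the traceback, which is consistent with NeuralODE receiving a Batch object as its state instead of a tensor.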
The Learner is defined following the format described in the TorchDyn quickstart guide, and train_loader and val_loader come from torch_geometric.dataloader and contain torch_geometric.batch objects, which are unpacked in the GCN's forward function as follows:
node_feats = data.x # torch.Tensor
edge_index = data.edge_index # torch.Tensor
graph_feats = data.graph_feats # torch.Tensor
node_feats = gcn_layers(node_feats, edge_index, graph_feats)
return node_feats
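Since an ODE solver only performs tensor arithmetic on the state (for example x + dt * f(x)), one possible workaround is to integrate the node-feature tensor data.x and keep edge_index (and graph_feats, if needed) on the vector-field module instead of passing the whole Batch through NeuralODE. The sketch below assumes plain PyTorch: GraphVectorField, set_graph and euler_step are hypothetical names, and the simple sum-aggregation layer only stands in for the real GCN layers.

```python
import torch
import torch.nn as nn


class GraphVectorField(nn.Module):
    """Hypothetical vector field f(x) over node features. The graph structure
    is attached with set_graph() instead of travelling through the solver."""

    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.edge_index = None  # must be set before each integration

    def set_graph(self, edge_index):
        self.edge_index = edge_index

    def forward(self, x):
        # x: [num_nodes, dim] tensor, the only thing the solver manipulates
        src, dst = self.edge_index
        # sum the features of incoming neighbours for every node
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        return torch.tanh(self.lin(agg))


def euler_step(f, x, dt):
    # one explicit Euler step: plain tensor arithmetic on the state
    return x + dt * f(x)


if __name__ == "__main__":
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # toy path graph 0-1-2
    f = GraphVectorField(4)
    f.set_graph(edge_index)
    x1 = euler_step(f, torch.randn(3, 4), 0.1)
    print(x1.shape)
```

Under these assumptions, a real batch would be handled by calling f.set_graph(batch.edge_index) and passing batch.x (not batch) as the initial state to the NeuralODE.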
Although the error is raised inside torch_geometric, the problem clearly stems from torchdyn, since the GCN works correctly in isolation. Any help debugging this error would be appreciated.
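If DataControl is kept once the state is a tensor, its effect on the feature width also matters. In my reading of torchdyn 1.0 (worth verifying against the installed version), DataControl stores the initial condition u and concatenates it to the current state along dim 1, so the GCN that follows it receives twice as many node features as it must return. The plain-PyTorch mimic below illustrates only that shape contract; DataControlLike is a hypothetical name, not the torchdyn class.

```python
import torch
import torch.nn as nn


class DataControlLike(nn.Module):
    """Hypothetical mimic of torchdyn's DataControl: concatenates the stored
    initial condition u to the current state along dim 1."""

    def __init__(self):
        super().__init__()
        self.u = None  # filled from the initial state before integration

    def forward(self, x):
        return torch.cat([x, self.u], 1)


if __name__ == "__main__":
    num_nodes, dim = 5, 8
    x0 = torch.randn(num_nodes, dim)
    dc = DataControlLike()
    dc.u = x0.detach()  # analogous to module.u = x[:, excess_dims:].detach()
    # the layer after DataControl must map 2 * dim inputs back to dim outputs
    f = nn.Sequential(dc, nn.Linear(2 * dim, dim))
    print(dc(x0).shape, f(x0).shape)
```

In other words, with DataControl in front, the first GCN layer would need an input size of twice the node-feature dimension while its output keeps the original size, so that the derivative has the same shape as the state.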