AssertionError в torch_geometric.nn.GATConv
Я пытаюсь использовать модуль сети графического внимания (GAT) в
torch_geometric
но продолжай натыкаться
AssertionError: Static graphs not supported in 'GATConv'
со следующим кодом.
class GraphConv_sum(nn.Module):
def __init__(self, in_ch, out_ch, num_layers, block, adj):
super(GraphConv_sum, self).__init__()
adj_coo = coo_matrix(adj) # convert the adjacency matrix to COO format for Pytorch Geometric
self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
self.g_conv = nn.ModuleList()
self.act = nn.LeakyReLU()
for n in range(num_layers):
if n == 0:
self.g_conv.append(block(in_ch, 16))
elif n > 0 and n < num_layers - 1:
self.g_conv.append(block(16, 16))
else:
self.g_conv.append(block(16, out_ch))
def forward(self, x):
for layer in self.g_conv:
x = layer(x=x, edge_index=self.edge_index)
x = self.act(x)
print(x.shape)
return x[:, 0, :]
Когда я заменю
block
с участием
GATConv
после стандартного цикла обучения возникает эта ошибка (другие слои преобразования, такие как
GCNConv
или
SAGEConv
проблем не было). Я проверил документацию и убедился, что входная форма верна (то же самое и для других конверсионных слоев).
В исходнике есть вот это
assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
часть в
forward
метод, но, по-видимому, размер партии будет играть роль в прямом проходе и
x.dim()
будет 3. Форма ввода с пакетным размером [1024, 6, 200]. Однако, если я вручную изменю условие утверждения на
x.dim() == 3
все равно будет выдаваться такая же ошибка, как если бы условие не выполнялось. У меня только высокий уровень понимания GAT, поэтому я могу что-то упустить. В любом случае, у меня есть несколько вопросов из этого
- Возможны ли какие-либо ошибки реализации с моей стороны, вызвавшие эту ошибку?
- Для чего это условие утверждения? Что такое статический граф в этом случае?
Буду признателен за любую информацию и помощь!! Спасибо!
1 ответ
Оказывается, из-за расчета веса внимания GATConv не поддерживает несколько матриц признаков и одиночный edge_index. Дополнительная информация: https://github.com/pyg-team/pytorch_geometric/issues/2844 .