AssertionError в torch_geometric.nn.GATConv

Я пытаюсь использовать модуль сети графического внимания (GAT) в torch_geometricно продолжай натыкаться AssertionError: Static graphs not supported in 'GATConv'со следующим кодом.

      class GraphConv_sum(nn.Module):
    def __init__(self, in_ch, out_ch, num_layers, block, adj):
        super(GraphConv_sum, self).__init__()
        adj_coo = coo_matrix(adj) # convert the adjacency matrix to COO format for Pytorch Geometric
        self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
        self.g_conv = nn.ModuleList()
        
        self.act = nn.LeakyReLU()

        for n in range(num_layers):
            if n == 0:
                self.g_conv.append(block(in_ch, 16))
            elif n > 0 and n < num_layers - 1:
                self.g_conv.append(block(16, 16))
            else:
                self.g_conv.append(block(16, out_ch))

    def forward(self, x):
        for layer in self.g_conv:
            x = layer(x=x, edge_index=self.edge_index)
            x = self.act(x)
            print(x.shape)
        return x[:, 0, :]

Когда я заменю blockс участием GATConvпосле стандартного цикла обучения возникает эта ошибка (другие слои преобразования, такие как GCNConvили SAGEConvпроблем не было). Я проверил документацию и убедился, что входная форма верна (то же самое и для других конверсионных слоев).

В исходнике есть вот это assert x.dim() == 2, "Static graphs not supported in 'GATConv'"часть в forwardметод, но, по-видимому, размер партии будет играть роль в прямом проходе и x.dim()будет 3. Форма ввода с пакетным размером [1024, 6, 200]. Однако, если я вручную изменю условие утверждения на x.dim() == 3все равно будет выдаваться такая же ошибка, как если бы условие не выполнялось. У меня только высокий уровень понимания GAT, поэтому я могу что-то упустить. В любом случае, у меня есть несколько вопросов из этого

  • Возможны ли какие-либо ошибки реализации с моей стороны, вызвавшие эту ошибку?
  • Для чего это условие утверждения? Что такое статический граф в этом случае?

Буду признателен за любую информацию и помощь!! Спасибо!

1 ответ

Оказывается, из-за расчета веса внимания GATConv не поддерживает несколько матриц признаков и одиночный edge_index. Дополнительная информация: https://github.com/pyg-team/pytorch_geometric/issues/2844 .

Другие вопросы по тегам