PyTorch: вычисление внимания с эффективным использованием памяти в сети графового внимания (GAT)
В настоящее время я разрабатываю вариант оригинальной сети внимания графа (GAT), следуя этой реализации .
Моя цель - включить краевые функции в качестве дополнительных входных данных (единственное отличие). Матрицы смежности представлены не как разреженная матрица NxN, а как сжатая версия, которая отслеживает только связанные узлы.
Оценка внимания вычисляется путем объединения функций узлов и ребер. Следовательно, поскольку я работаю с пакетами графов, имеющих 200 узлов, это приводит к созданию слишком больших матриц, которые невозможно загрузить даже на GPU RTX 3090 24 ГБ.
Это мой код:
linear_transformed_nodes_repeated = linear_transformed_nodes.repeat(n_nodes, 1)
linear_transformed_nodes_repeated_interleave = linear_transformed_nodes.repeat_interleave(n_nodes, dim=0)
#Node concatenation
linear_transformed_nodes_concat = torch.cat([linear_transformed_nodes_repeated_interleave, linear_transformed_nodes_repeated], dim=-1)
#Each concatenation is now repeated n_edges times becuase will be concatenated with every edge
inear_transformed_nodes_concat = linear_transformed_nodes_concat.repeat_interleave(n_edges, dim = 0)
linear_transformed_edges_repeated = linear_transformed_edges.repeat(n_nodes * n_nodes,1)
#Node-Node-Edge concatenation
nodes_edge_concatenation = torch.cat([linear_transformed_nodes_concat, linear_transformed_edges_repeated], dim=1)
#Reshape the matrix so that A[x][y][z] contains the concatenation between nodes x, y and edge z
nodes_edge_concatenation = nodes_edge_concatenation.view(n_nodes, n_nodes, n_edges, self.out_features_nodes * 2 + self.out_features_edges)
e = self.activation(self.attn_nodes(nodes_edge_concatenation))
a = self.softmax(e)
Основная проблема этой реализации — размер матриц конкатенации, например, графа с