PyTorch: вычисление внимания с эффективным использованием памяти в сети графового внимания (GAT)

В настоящее время я разрабатываю вариант оригинальной сети внимания графа (GAT), следуя этой реализации .

Моя цель - включить краевые функции в качестве дополнительных входных данных (единственное отличие). Матрицы смежности представлены не как разреженная матрица NxN, а как сжатая версия, которая отслеживает только связанные узлы.

Оценка внимания вычисляется путем объединения функций узлов и ребер. Следовательно, поскольку я работаю с пакетами графов, имеющих 200 узлов, это приводит к созданию слишком больших матриц, которые невозможно загрузить даже на GPU RTX 3090 24 ГБ.

Это мой код:

      linear_transformed_nodes_repeated = linear_transformed_nodes.repeat(n_nodes, 1)
linear_transformed_nodes_repeated_interleave = linear_transformed_nodes.repeat_interleave(n_nodes, dim=0)
#Node concatenation
linear_transformed_nodes_concat = torch.cat([linear_transformed_nodes_repeated_interleave, linear_transformed_nodes_repeated], dim=-1)
#Each concatenation is now repeated n_edges times becuase will be concatenated with every edge
inear_transformed_nodes_concat = linear_transformed_nodes_concat.repeat_interleave(n_edges, dim = 0)
linear_transformed_edges_repeated = linear_transformed_edges.repeat(n_nodes * n_nodes,1)
#Node-Node-Edge concatenation
nodes_edge_concatenation = torch.cat([linear_transformed_nodes_concat, linear_transformed_edges_repeated], dim=1)
#Reshape the matrix so that A[x][y][z] contains the concatenation between nodes x, y and edge z
nodes_edge_concatenation = nodes_edge_concatenation.view(n_nodes, n_nodes, n_edges, self.out_features_nodes * 2 + self.out_features_edges)
e = self.activation(self.attn_nodes(nodes_edge_concatenation))
a = self.softmax(e)

просто линейное преобразование, а является LeakyReLU. После этого я просто создаю вложение нового узла, суммируя его окрестности.

Основная проблема этой реализации — размер матриц конкатенации, например, графа с производит после 4-го преобразования матрицу с формой , при условии, что каждые 200 объектов на узел. Есть ли способ вычислить показатель внимания без вычисления всех этих промежуточных тензоров?

0 ответов

Другие вопросы по тегам