Ускорение операций pytorch для исключения пользовательских сообщений
Я пытаюсь реализовать удаление сообщений в моей пользовательской свертке MessagePassing в PyTorch Geometric. Выпадение сообщения состоит из случайного игнорирования p% ребер в графе. Моя идея состояла в том, чтобы случайным образом удалить p% из них из ввода в
forward()
.
В
edge_index
является тензором формы
(2, num_edges)
где 1-е измерение - это идентификатор узла "от", а 2-е - это идентификатор узла "до". Итак, я подумал, что могу сделать, это выбрать случайную выборку из
range(N)
а затем используйте его, чтобы замаскировать остальные индексы:
def forward(self, x, edge_index, edge_attr=None):
if self.message_dropout is not None:
# TODO: this is way too slow (4-5 times slower than without it)
# message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
edge_index_to_use = edge_index[:, random_keep_inx]
edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
else:
edge_index_to_use = edge_index
edge_attr_to_use = edge_attr
...
Однако он слишком медленный, он заставляет эпоху идти 5 минут вместо 1 минуты (в 5 раз медленнее). Есть ли более быстрый способ сделать это в PyTorch?
Изменить: узким местом, по-видимому, является
random.sample()
вызов, а не маскировка. Поэтому я думаю, что я должен просить о более быстрых альтернативах этому.