Ускорение операций pytorch для исключения пользовательских сообщений

Я пытаюсь реализовать удаление сообщений в моей пользовательской свертке MessagePassing в PyTorch Geometric. Выпадение сообщения состоит из случайного игнорирования p% ребер в графе. Моя идея состояла в том, чтобы случайным образом удалить p% из них из ввода в forward().

В edge_indexявляется тензором формы (2, num_edges)где 1-е измерение - это идентификатор узла "от", а 2-е - это идентификатор узла "до". Итак, я подумал, что могу сделать, это выбрать случайную выборку из range(N)а затем используйте его, чтобы замаскировать остальные индексы:

          def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

Однако он слишком медленный, он заставляет эпоху идти 5 минут вместо 1 минуты (в 5 раз медленнее). Есть ли более быстрый способ сделать это в PyTorch?

Изменить: узким местом, по-видимому, является random.sample()вызов, а не маскировка. Поэтому я думаю, что я должен просить о более быстрых альтернативах этому.

0 ответов

Другие вопросы по тегам