Обновление определенных векторных элементов в PyTorch

У меня есть большой вектор, который я хотел бы обновить. Я обновлю его, добавив смещение для определенных элементов в векторе. Я указываю вектор индексов, которые я хочу обновить (вызвать индексный вектор ix), и для каждого индекса я указываю значение, которое я хочу добавить к этому элементу (вызывать вектор значения vals). Если все элементы вектора индекса уникальны, тогда достаточно следующего кода:

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,2], dtype=torch.long)
vals = torch.tensor([0.2, 0.5], dtype=torch.float)
vec[ix] += vals

Однако это не работает, если в ix, Наивный подход для случая повторных индексов заключается в следующем:

for i in range(len(ix)):
    vec[ix[i]] += vals[i]

Но это плохо масштабируется - это очень медленно, когда ix большой. Есть ли более быстрые способы сделать это? Если бы был быстрый способ суммировать все записи vals которые имеют одинаковый индекс в ixТогда решение должно быть простым.

Обновить:
Я нашел одно решение, которое работает довольно хорошо, описано ниже. Я все еще хотел бы обратную связь для лучших решений.

# get unique indices
ix_unique = torch.unique(ix)

# for each unique index, get sum of all vals with that index
vals_unique = torch.stack([
    torch.sum(torch.where(ix==i, vals, torch.zeros_like(vals))) 
    for i in ix_unique
])

# update vec
vec[ix_unique] += vals_unique

0 ответов

Для случаев, когда вы хотите разрешить несколько обновлений одних и тех же индексов ix, также существует библиотека под названием pytorch_scatter. В таких случаях наличие, например, 3 идентичных знаков в ix приведет к добавлению к этому индексу 3*val.

torch.index_add()

import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,0,2], dtype=torch.long)
vals = torch.tensor([0.2,0.1,0.5], dtype=torch.float)
torch.index_add(vec, 0, ix, vals)

и вы получите

tensor([0.3000, 0.0000, 0.5000, 0.0000])

Ссылка: официальный документ

Другие вопросы по тегам