Как преобразовать список тензора Torch с градиентом в тензор
У меня есть переменная pts, имеющая форму [batch, ch, h, w]. Это тепловая карта, и я хочу преобразовать ее во вторую координату. Цель: pts_o = heatmap_to_pts(pts), где pts_o будет [batch, ch, 2]. Я написал эту функцию до сих пор,
def heatmap_to_pts(self, pts): <- pts [batch, 68, 128, 128]
pt_num = []
for i in range(len(pts)):
pt = pts[i]
if type(pt) == torch.Tensor:
d = torch.tensor(128) * get the
m = pt.view(68, -1).argmax(1) * indices
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1) * from heatmaps
pt_num.append(indices.type(torch.DoubleTensor) ) <- store the indices in a list
b = torch.Tensor(68, 2) * trying to convert
c = torch.cat(pt_num, out=b) *error* * a list of tensors with grad
c = c.reshape(68,2) * to a tensor like [batch, 68, 2]
return c
Ошибка говорит: "cat(): функции с аргументами out=... не поддерживают автоматическое дифференцирование, но для одного из аргументов требуется градиент". Он не может выполнять операции, потому что тензорам в pt_num требуется grad".
Как мне преобразовать этот список в тензор?
1 ответ
Ошибка говорит:
cat(): функции с аргументами out=... не поддерживают автоматическое дифференцирование, но для одного из аргументов требуется grad.
Это означает, что вывод таких функций, как
torch.cat()
который как
out=
kwarg нельзя использовать в качестве входных данных для механизма autograd (который выполняет автоматическое дифференцирование).
Причина в том, что тензоры (в вашем списке Python
pt_num
) имеют разные значения для
requires_grad
атрибут, т.е. некоторые тензоры имеют
requires_grad=True
в то время как некоторые из них
requires_grad=False
.
В вашем коде следующая строка (логически) проблематична:
c = torch.cat(pt_num, out=b)
Возвращаемое значение
torch.cat()
, независимо от того, используете ли вы
out=
kwarg или нет, это конкатенация тензоров по указанной размерности.
Итак, тензор
c
уже является объединенной версией отдельных тензоров в
pt_num
. С помощью
out=b
избыточный. Таким образом, вы можете просто избавиться от
out=b
и все должно быть хорошо.
c = torch.cat(pt_num)