Как преобразовать список тензора Torch с градиентом в тензор

У меня есть переменная pts, имеющая форму [batch, ch, h, w]. Это тепловая карта, и я хочу преобразовать ее во вторую координату. Цель: pts_o = heatmap_to_pts(pts), где pts_o будет [batch, ch, 2]. Я написал эту функцию до сих пор,

def heatmap_to_pts(self, pts):  <- pts [batch, 68, 128, 128]
    
    pt_num = []
    
    for i in range(len(pts)):
        
        pt = pts[i]
        if type(pt) == torch.Tensor:

            d = torch.tensor(128)                                                   * get the   
            m = pt.view(68, -1).argmax(1)                                           * indices
            indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)  * from heatmaps
        
            pt_num.append(indices.type(torch.DoubleTensor) )   <- store the indices in a list

    b = torch.Tensor(68, 2)                   * trying to convert
    c = torch.cat(pt_num, out=b) *error*      * a list of tensors with grad
    c = c.reshape(68,2)                       * to a tensor like [batch, 68, 2]

    return c

Ошибка говорит: "cat(): функции с аргументами out=... не поддерживают автоматическое дифференцирование, но для одного из аргументов требуется градиент". Он не может выполнять операции, потому что тензорам в pt_num требуется grad".

Как мне преобразовать этот список в тензор?

1 ответ

Решение

Ошибка говорит:

cat(): функции с аргументами out=... не поддерживают автоматическое дифференцирование, но для одного из аргументов требуется grad.

Это означает, что вывод таких функций, как torch.cat() который как out= kwarg нельзя использовать в качестве входных данных для механизма autograd (который выполняет автоматическое дифференцирование).

Причина в том, что тензоры (в вашем списке Python pt_num) имеют разные значения для requires_grad атрибут, т.е. некоторые тензоры имеют requires_grad=True в то время как некоторые из них requires_grad=False.

В вашем коде следующая строка (логически) проблематична:

c = torch.cat(pt_num, out=b) 

Возвращаемое значение torch.cat(), независимо от того, используете ли вы out= kwarg или нет, это конкатенация тензоров по указанной размерности.

Итак, тензор c уже является объединенной версией отдельных тензоров в pt_num. С помощью out=bизбыточный. Таким образом, вы можете просто избавиться от out=b и все должно быть хорошо.

c = torch.cat(pt_num)
Другие вопросы по тегам