Torch C++: получение значения тензора int с помощью *.data<int> ()

В версии Libtorch для C++ я обнаружил, что могу получить значение тензора с плавающей запятой *tensor_name[0].data<float>()в котором вместо 0 Я могу использовать любой другой действительный индекс. Но когда я определил int тензор, добавив опцию at::kInt в создании тензора, я не могу использовать эту структуру, чтобы получить значение тензора, то есть что-то вроде *tensor_name[0].data<at::kInt>() или же *tensor_name[0].data<int>() не работает и отладчик продолжает говорить, что Couldn't find method at::Tensor::data<at::kInt> или же Couldn't find method at::Tensor::data<int>, Я могу получить значения по auto value_array = tensor_name=accessor<int,1>(), но это было проще в использовании *tensor_name[0].data<int>(), Можете ли вы дать мне знать, как я могу использовать data<>() чтобы получить значение int тензор?

У меня тоже такая же проблема с bool тип.

Спасибо афшин

1 ответ

Решение

Использование item<dtype>() вытащить скаляр из тензора.

int main() {
  torch::Tensor tensor = torch::randint(20, {2, 3});
  std::cout << tensor << std::endl;
  int a = tensor[0][0].item<int>();
  std::cout << a << std::endl;
  return 0;
}

~/l/build ❯❯❯ ./example-app
  3  10   3
  2   5   8
[ Variable[CPUFloatType]{2,3} ]
3

Следующий код печатает 0 (протестировано в Linux со стабильным libtorch):

#include <torch/script.h>
#include <iostream>                                     

int main(int argc, const char* argv[])                  
{
    auto indx = torch::zeros({20},at::dtype(at::kLong));
    std::cout << indx[0].item<long>() << std::endl;

    return 0;
}
Другие вопросы по тегам