Torch C++: получение значения тензора int с помощью *.data<int> ()
В версии Libtorch для C++ я обнаружил, что могу получить значение тензора с плавающей запятой *tensor_name[0].data<float>()
в котором вместо 0
Я могу использовать любой другой действительный индекс. Но когда я определил int
тензор, добавив опцию at::kInt
в создании тензора, я не могу использовать эту структуру, чтобы получить значение тензора, то есть что-то вроде *tensor_name[0].data<at::kInt>()
или же *tensor_name[0].data<int>()
не работает и отладчик продолжает говорить, что Couldn't find method at::Tensor::data<at::kInt>
или же Couldn't find method at::Tensor::data<int>
, Я могу получить значения по auto value_array = tensor_name=accessor<int,1>()
, но это было проще в использовании *tensor_name[0].data<int>()
, Можете ли вы дать мне знать, как я могу использовать data<>()
чтобы получить значение int
тензор?
У меня тоже такая же проблема с bool
тип.
Спасибо афшин
1 ответ
Использование item<dtype>()
вытащить скаляр из тензора.
int main() {
torch::Tensor tensor = torch::randint(20, {2, 3});
std::cout << tensor << std::endl;
int a = tensor[0][0].item<int>();
std::cout << a << std::endl;
return 0;
}
~/l/build ❯❯❯ ./example-app
3 10 3
2 5 8
[ Variable[CPUFloatType]{2,3} ]
3
Следующий код печатает 0
(протестировано в Linux со стабильным libtorch):
#include <torch/script.h>
#include <iostream>
int main(int argc, const char* argv[])
{
auto indx = torch::zeros({20},at::dtype(at::kLong));
std::cout << indx[0].item<long>() << std::endl;
return 0;
}