9 #include <boost/test/unit_test.hpp>
15 #include <torch/torch.h>
18 std::vector<int64_t> start_vec = {
29 BOOST_CHECK(tensor.options().dtype() == torch::kInt64);
30 BOOST_CHECK(tensor.sizes().size() == 2);
31 BOOST_CHECK(tensor.size(0) == 4);
32 BOOST_CHECK(tensor.size(1) == 2);
34 BOOST_CHECK(tensor[0][0].item<int64_t>() == 0);
35 BOOST_CHECK(tensor[0][1].item<int64_t>() == 1);
37 BOOST_CHECK(tensor[1][0].item<int64_t>() == 1);
38 BOOST_CHECK(tensor[1][1].item<int64_t>() == 2);
40 BOOST_CHECK(tensor[2][0].item<int64_t>() == 2);
41 BOOST_CHECK(tensor[2][1].item<int64_t>() == 3);
43 BOOST_CHECK(tensor[3][0].item<int64_t>() == 3);
44 BOOST_CHECK(tensor[3][1].item<int64_t>() == 4);
46 auto test_vec = Acts::detail::tensor2DToVector<int64_t>(tensor);
48 BOOST_CHECK(test_vec == start_vec);
52 std::vector<float> start_vec = {
63 BOOST_CHECK(tensor.options().dtype() == torch::kFloat32);
64 BOOST_CHECK(tensor.sizes().size() == 2);
65 BOOST_CHECK(tensor.size(0) == 4);
66 BOOST_CHECK(tensor.size(1) == 3);
68 for (
auto i : {0, 1, 2, 3}) {
69 BOOST_CHECK(tensor[
i][0].item<int64_t>() == static_cast<float>(
i));
70 BOOST_CHECK(tensor[
i][1].item<int64_t>() == static_cast<float>(
i));
71 BOOST_CHECK(tensor[
i][2].item<int64_t>() == static_cast<float>(
i));
74 auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
76 BOOST_CHECK(test_vec == start_vec);
80 std::vector<float> start_vec = {
91 using namespace torch::indexing;
92 tensor = tensor.index({Slice{}, Slice{0, None, 2}});
94 BOOST_CHECK(tensor.size(0) == 4);
95 BOOST_CHECK(tensor.size(1) == 2);
97 const std::vector<float> ref_vec = {
106 const auto test_vec = Acts::detail::tensor2DToVector<float>(tensor);
108 BOOST_CHECK(test_vec == ref_vec);