16 #include <torch/torch.h>
18 namespace Acts::detail {
26 constexpr
static torch::Dtype
type = torch::kInt64;
31 constexpr
static torch::Dtype
type = torch::kInt32;
36 constexpr
static torch::Dtype
type = torch::kInt16;
41 constexpr
static torch::Dtype
type = torch::kInt8;
46 constexpr
static torch::Dtype
type = torch::kFloat32;
51 constexpr
static torch::Dtype
type = torch::kFloat64;
61 assert(vec.size() % cols == 0);
66 return torch::from_blob(
68 {
static_cast<long>(vec.size() / cols), static_cast<long>(cols)},
opts);
75 assert(tensor.sizes().size() == 2);
82 at::Tensor transformedTensor =
85 std::vector<T> edgeIndex(
86 transformedTensor.template data_ptr<T>(),
87 transformedTensor.template data_ptr<T>() + transformedTensor.numel());