#include <torch/script.h>
#include <torch/torch.h>

#ifndef ACTS_EXATRKX_CPUONLY
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <grid/counting_sort.h>
#include <grid/find_nbrs.h>
#include <grid/grid.h>
#include <grid/insert_points.h>
#include <grid/prefix_sum.h>
#endif

using namespace torch::indexing;
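
/// Post-process an edge tensor: optionally drop self-loops, deduplicate
/// edges (treating (a, b) and (b, a) as the same edge), and randomly flip
/// edge directions so the edge list carries no directional bias.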
torch::Tensor Acts::detail::postprocessEdgeTensor(torch::Tensor edges,
                                                  bool removeSelfLoops,
                                                  bool removeDuplicates,
                                                  bool flipDirections) {
  // Remove self-loops
  if (removeSelfLoops) {
    torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
    edges = edges.index({Slice(), selfLoopMask});
  }

  // Remove duplicates: bring each edge into (low, high) order, then drop
  // repeated columns
  if (removeDuplicates) {
    torch::Tensor mask = edges.index({0}) > edges.index({1});
    edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
    edges = std::get<0>(torch::unique_dim(edges, -1, false));
  }

  // Randomly flip the direction of about half of the edges
  if (flipDirections) {
    torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
    torch::Tensor random_cut_flip = 1 - random_cut_keep;
    torch::Tensor keep_edges =
        edges.index({Slice(), random_cut_keep.to(torch::kBool)});
    torch::Tensor flip_edges =
        edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
    edges = torch::cat({keep_edges, flip_edges}, 1);
  }

  return edges.toType(torch::kInt64);
}
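
/// Build edges with a fixed-radius nearest-neighbor (FRNN) search on the
/// GPU: spacepoints are binned into a uniform grid, sorted cell-by-cell
/// with a counting sort, and neighbors are then searched only in nearby
/// cells instead of across all point pairs.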
torch::Tensor Acts::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
                                           float rVal, int kVal,
                                           bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  const auto device = embedFeatures.device();

  const int64_t numSpacepoints = embedFeatures.size(0);
  const int dim = embedFeatures.size(1);

  const int grid_params_size = 8;
  const int grid_delta_idx = 3;
  const int grid_total_idx = 7;
  const int grid_max_res = 128;
  const int grid_dim = 3;
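
  // Layout of one row of the grid parameter tensor built below:
  //   [0, 3)  grid minimum (x, y, z)
  //   [3]     inverse cell size (delta)
  //   [4, 7)  number of cells per axis
  //   [7]     total number of cells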

  if (dim < 3) {
    throw std::runtime_error("DIM < 3 is not supported for now.\n");
  }

  const float radius_cell_ratio = 2.0;
  const int batch_size = 1;
  int G = -1;

  // Set up grid properties
  torch::Tensor grid_min;
  torch::Tensor grid_max;
  torch::Tensor grid_size;

  torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
  torch::Tensor gridParamsCuda =
      torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
  torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
  torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);

  // Build the grid: compute the bounding box, cell size and resolution
  for (int i = 0; i < batch_size; i++) {
    torch::Tensor allPoints = embedTensor.index(
        {i, Slice(None, lengths.index({i}).item().to<long>()),
         Slice(None, grid_dim)});
    grid_min = std::get<0>(allPoints.min(0));
    grid_max = std::get<0>(allPoints.max(0));
    gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);

    grid_size = grid_max - grid_min;

    float cell_size =
        r_tensor.index({i}).item().to<float>() / radius_cell_ratio;

    if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
      cell_size = grid_size.min().item().to<float>() / grid_max_res;
    }

    gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);

    gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
                              floor(grid_size / cell_size) + 1);

    gridParamsCuda.index_put_(
        {i, grid_total_idx},
        gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
            .prod());

    if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
      G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
    }
  }
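
  // G is the largest total cell count over the batch; it sizes the
  // per-cell bookkeeping tensors below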
  torch::Tensor pc_grid_cnt =
      torch::zeros({batch_size, G}, device).to(torch::kInt32);
  torch::Tensor pc_grid_cell =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
  torch::Tensor pc_grid_idx =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  // Put the spacepoints into the grid
  InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
                   pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);
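
  // The next steps follow a counting-sort pattern: count the points in each
  // cell, turn the counts into offsets with a prefix sum, then scatter the
  // points into cell-sorted order so each cell is a contiguous block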
  torch::Tensor pc_grid_off =
      torch::full({batch_size, G}, 0, device).to(torch::kInt32);
  torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);

  pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);

  torch::Tensor sorted_points =
      torch::zeros({batch_size, numSpacepoints, dim}, device)
          .to(torch::kFloat32);
  torch::Tensor sorted_points_idxs =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
                   pc_grid_idx, pc_grid_off, sorted_points,
                   sorted_points_idxs);

  // Fixed-radius nearest-neighbor search on the cell-sorted points
  auto [indices, distances] = FindNbrsCUDA(
      sorted_points, sorted_points, lengths.to(torch::kInt64),
      lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
      sorted_points_idxs, sorted_points_idxs,
      gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
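
  // Only non-negative indices correspond to found neighbors; entries where
  // fewer than kVal neighbors lie within the radius are filled with -1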
  torch::Tensor positiveIndices = indices >= 0;

  torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
                                  .repeat({1, positiveIndices.size(2), 1})
                                  .transpose(1, 2);

  torch::Tensor stackedEdges = torch::stack(
      {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});

  return postprocessEdgeTensor(std::move(stackedEdges), true, true,
                               flipDirections);
#else
  throw std::runtime_error(
      "ACTS not compiled with CUDA, cannot run Acts::buildEdgesFRNN");
#endif
}

/// A very simple span-like wrapper that avoids data copies in the KDTree
/// search; could be replaced by std::span once available
template <typename T, std::size_t S>
struct Span {
  T *ptr;

  auto size() const { return S; }

  using const_iterator = T const *;
  const_iterator cbegin() const { return ptr; }
  const_iterator cend() const { return ptr + S; }

  auto &operator[](std::size_t i) const { return ptr[i]; }
};

// Euclidean distance between two points viewed as spans
template <std::size_t Dim>
float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
  float s = 0;
  for (auto i = 0ul; i < Dim; ++i) {
    s += (a[i] - b[i]) * (a[i] - b[i]);
  }
  return std::sqrt(s);
}
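
/// Dimension-templated worker for the CPU fallback: builds an Acts::KDTree
/// over the embedding space and collects all pairs of points closer than
/// rVal to each other.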
template <std::size_t Dim>
struct BuildEdgesKDTree {
  static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
                              int kVal) {
    assert(embedFeatures.size(1) == Dim);
    embedFeatures = embedFeatures.to(torch::kCPU);

    // Build the tree
    using KDTree = Acts::KDTree<Dim, int, float, Span>;

    typename KDTree::vector_t features;
    features.reserve(embedFeatures.size(0));

    auto dataPtr = embedFeatures.data_ptr<float>();

    for (int i = 0; i < embedFeatures.size(0); ++i) {
      features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
    }

    KDTree tree(std::move(features));

    // Search the tree
    std::vector<int32_t> edges;
    edges.reserve(2 * kVal * embedFeatures.size(0));

    for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
      const Span<float, Dim> self{dataPtr + iself * Dim};

      // Query box around the current point; candidates inside the box are
      // filtered with the exact euclidean distance below
      Acts::RangeXD<Dim, float> range;
      for (auto j = 0ul; j < Dim; ++j) {
        range[j] = Acts::Range1D<float>(self[j] - rVal, self[j] + rVal);
      }

      tree.rangeSearchMapDiscard(
          range, [&](const Span<float, Dim> &other, const int &iother) {
            if (iself != iother && dist(self, other) <= rVal) {
              edges.push_back(iself);
              edges.push_back(iother);
            }
          });
    }

    // Transpose to get a (2, num_edges) tensor; clone to own the memory
    return Acts::detail::vectorToTensor2D(edges, 2).t().clone();
  }
};

torch::Tensor Acts::detail::buildEdgesKDTree(torch::Tensor &embedFeatures,
                                             float rVal, int kVal,
                                             bool flipDirections) {
  auto tensor = Acts::template_switch<BuildEdgesKDTree, 1, 12>(
      embedFeatures.size(1), embedFeatures, rVal, kVal);

  return postprocessEdgeTensor(tensor, true, true, flipDirections);
}
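
/// Entry point that dispatches to the CUDA FRNN implementation when ACTS is
/// built with CUDA and a device is available, and to the CPU KDTree
/// implementation otherwise.
///
/// A usage sketch (the radius and neighbor count are made-up values):
///   torch::Tensor emb = ...;  // (numSpacepoints, embeddingDim), float
///   auto edges = Acts::detail::buildEdges(emb, 0.2f, 100, false);
///   // edges is a (2, numEdges) int64 tensor of source/target indices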
torch::Tensor Acts::detail::buildEdges(torch::Tensor &embedFeatures,
                                       float rVal, int kVal,
                                       bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  if (torch::cuda::is_available()) {