Analysis Software
Documentation for sPHENIX simulation software
buildEdges.cpp
Go to the documentation of this file, or view the newest version of buildEdges.cpp in the sPHENIX GitHub repository.
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"

#include "Acts/Utilities/Helpers.hpp"
#include "Acts/Utilities/KDTree.hpp"

#include <cassert>
#include <cmath>
#include <iostream>
#include <mutex>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

#ifndef ACTS_EXATRKX_CPUONLY
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <grid/counting_sort.h>
#include <grid/find_nbrs.h>
#include <grid/grid.h>
#include <grid/insert_points.h>
#include <grid/prefix_sum.h>
#endif

using namespace torch::indexing;

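// Clean up a [2, nEdges] edge-index tensor: optionally drop self-loops,
// remove duplicate edges (after canonicalizing each pair so that the smaller
// node index comes first), and randomly flip the direction of roughly half
// of the edges.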
torch::Tensor Acts::detail::postprocessEdgeTensor(torch::Tensor edges,
                                                  bool removeSelfLoops,
                                                  bool removeDuplicates,
                                                  bool flipDirections) {
  // Remove self-loops
  if (removeSelfLoops) {
    torch::Tensor selfLoopMask = edges.index({0}) != edges.index({1});
    edges = edges.index({Slice(), selfLoopMask});
  }

  // Remove duplicates
  if (removeDuplicates) {
    torch::Tensor mask = edges.index({0}) > edges.index({1});
    edges.index_put_({Slice(), mask}, edges.index({Slice(), mask}).flip(0));
    edges = std::get<0>(torch::unique_dim(edges, -1, false));
  }

  // Randomly flip direction
  if (flipDirections) {
    torch::Tensor random_cut_keep = torch::randint(2, {edges.size(1)});
    torch::Tensor random_cut_flip = 1 - random_cut_keep;
    torch::Tensor keep_edges =
        edges.index({Slice(), random_cut_keep.to(torch::kBool)});
    torch::Tensor flip_edges =
        edges.index({Slice(), random_cut_flip.to(torch::kBool)}).flip({0});
    edges = torch::cat({keep_edges, flip_edges}, 1);
  }

  return edges.toType(torch::kInt64);
}

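// Build edges on the GPU with a fixed-radius nearest-neighbour (FRNN) grid
// search: the embedded spacepoints are binned into a uniform grid, sorted by
// cell, and each point is connected to at most kVal neighbours within rVal.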
torch::Tensor Acts::detail::buildEdgesFRNN(torch::Tensor &embedFeatures,
                                           float rVal, int kVal,
                                           bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  const auto device = embedFeatures.device();

  const int64_t numSpacepoints = embedFeatures.size(0);
  const int dim = embedFeatures.size(1);

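  // Layout of the 8 grid parameters per batch entry:
  // [0, 3): grid minimum, [3]: inverse cell size ("delta"),
  // [4, 7): number of cells per dimension, [7]: total number of cells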
  const int grid_params_size = 8;
  const int grid_delta_idx = 3;
  const int grid_total_idx = 7;
  const int grid_max_res = 128;
  const int grid_dim = 3;

  if (dim < 3) {
    throw std::runtime_error("DIM < 3 is not supported for now.\n");
  }

  const float radius_cell_ratio = 2.0;
  const int batch_size = 1;
  int G = -1;

  // Set up grid properties
  torch::Tensor grid_min;
  torch::Tensor grid_max;
  torch::Tensor grid_size;

  torch::Tensor embedTensor = embedFeatures.reshape({1, numSpacepoints, dim});
  torch::Tensor gridParamsCuda =
      torch::zeros({batch_size, grid_params_size}, device).to(torch::kFloat32);
  torch::Tensor r_tensor = torch::full({batch_size}, rVal, device);
  torch::Tensor lengths = torch::full({batch_size}, numSpacepoints, device);

  // build the grid
  for (int i = 0; i < batch_size; i++) {
    torch::Tensor allPoints =
        embedTensor.index({i, Slice(None, lengths.index({i}).item().to<long>()),
                           Slice(None, grid_dim)});
    grid_min = std::get<0>(allPoints.min(0));
    grid_max = std::get<0>(allPoints.max(0));
    gridParamsCuda.index_put_({i, Slice(None, grid_delta_idx)}, grid_min);

    grid_size = grid_max - grid_min;

    float cell_size =
        r_tensor.index({i}).item().to<float>() / radius_cell_ratio;

    if (cell_size < (grid_size.min().item().to<float>() / grid_max_res)) {
      cell_size = grid_size.min().item().to<float>() / grid_max_res;
    }

    gridParamsCuda.index_put_({i, grid_delta_idx}, 1 / cell_size);

    gridParamsCuda.index_put_({i, Slice(1 + grid_delta_idx, grid_total_idx)},
                              floor(grid_size / cell_size) + 1);

    gridParamsCuda.index_put_(
        {i, grid_total_idx},
        gridParamsCuda.index({i, Slice(1 + grid_delta_idx, grid_total_idx)})
            .prod());

    if (G < gridParamsCuda.index({i, grid_total_idx}).item().to<int>()) {
      G = gridParamsCuda.index({i, grid_total_idx}).item().to<int>();
    }
  }

  torch::Tensor pc_grid_cnt =
      torch::zeros({batch_size, G}, device).to(torch::kInt32);
  torch::Tensor pc_grid_cell =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);
  torch::Tensor pc_grid_idx =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  // put spacepoints into the grid
  InsertPointsCUDA(embedTensor, lengths.to(torch::kInt64), gridParamsCuda,
                   pc_grid_cnt, pc_grid_cell, pc_grid_idx, G);

  torch::Tensor pc_grid_off =
      torch::full({batch_size, G}, 0, device).to(torch::kInt32);
  torch::Tensor grid_params = gridParamsCuda.to(torch::kCPU);

  // compute the per-cell offsets; a loop over batches is not necessary here
  pc_grid_off = PrefixSumCUDA(pc_grid_cnt, grid_params);

  torch::Tensor sorted_points =
      torch::zeros({batch_size, numSpacepoints, dim}, device)
          .to(torch::kFloat32);
  torch::Tensor sorted_points_idxs =
      torch::full({batch_size, numSpacepoints}, -1, device).to(torch::kInt32);

  // sort the spacepoints by their grid cell
  CountingSortCUDA(embedTensor, lengths.to(torch::kInt64), pc_grid_cell,
                   pc_grid_idx, pc_grid_off, sorted_points, sorted_points_idxs);

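  // for each point, find up to kVal neighbours within radius rVal
  // (both the radius and its square are passed to the kernel)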
  auto [indices, distances] = FindNbrsCUDA(
      sorted_points, sorted_points, lengths.to(torch::kInt64),
      lengths.to(torch::kInt64), pc_grid_off.to(torch::kInt32),
      sorted_points_idxs, sorted_points_idxs,
      gridParamsCuda.to(torch::kFloat32), kVal, r_tensor, r_tensor * r_tensor);
  torch::Tensor positiveIndices = indices >= 0;

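  // convert the fixed-size FRNN output into a [2, nEdges] edge list:
  // row 0 holds the query point, row 1 the found neighbour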
  torch::Tensor repeatRange = torch::arange(positiveIndices.size(1), device)
                                  .repeat({1, positiveIndices.size(2), 1})
                                  .transpose(1, 2);

  torch::Tensor stackedEdges = torch::stack(
      {repeatRange.index({positiveIndices}), indices.index({positiveIndices})});

  return postprocessEdgeTensor(std::move(stackedEdges), true, true,
                               flipDirections);
#else
  throw std::runtime_error(
      "ACTS not compiled with CUDA, cannot run Acts::buildEdgesFRNN");
#endif
}

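// Minimal non-owning view over S consecutive values of type T; used as the
// point type for the KD-tree search below.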
template <typename T, std::size_t S>
struct Span {
  T *ptr;

  auto size() const { return S; }

  using const_iterator = T const *;
  const_iterator cbegin() const { return ptr; }
  const_iterator cend() const { return ptr + S; }

  auto operator[](std::size_t i) const { return ptr[i]; }
};

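// Euclidean distance between two embedded points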
template <std::size_t Dim>
float dist(const Span<float, Dim> &a, const Span<float, Dim> &b) {
  float s = 0.f;
  for (auto i = 0ul; i < Dim; ++i) {
    s += (a[i] - b[i]) * (a[i] - b[i]);
  }
  return std::sqrt(s);
}

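// CPU fallback: build edges with an Acts KD-tree range search. The embedding
// dimension must be known at compile time, so this struct is instantiated for
// each supported dimension via Acts::template_switch below.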
template <std::size_t Dim>
struct BuildEdgesKDTree {
  static torch::Tensor invoke(torch::Tensor &embedFeatures, float rVal,
                              int kVal) {
    assert(embedFeatures.size(1) == Dim);
    embedFeatures = embedFeatures.to(torch::kCPU);

    ////////////////
    // Build tree //
    ////////////////
    using KDTree = Acts::KDTree<Dim, int, float, Span>;

    typename KDTree::vector_t features;
    features.reserve(embedFeatures.size(0));

    auto dataPtr = embedFeatures.data_ptr<float>();

    for (int i = 0; i < embedFeatures.size(0); ++i) {
      features.push_back({Span<float, Dim>{dataPtr + i * Dim}, i});
    }

    KDTree tree(std::move(features));

    /////////////////
    // Search tree //
    /////////////////
    std::vector<int32_t> edges;
    edges.reserve(2 * kVal * embedFeatures.size(0));

    for (int iself = 0; iself < embedFeatures.size(0); ++iself) {
      const Span<float, Dim> self{dataPtr + iself * Dim};

      Acts::RangeXD<Dim, float> range;
      for (auto j = 0ul; j < Dim; ++j) {
        range[j] = Acts::Range1D<float>(self[j] - rVal, self[j] + rVal);
      }

      tree.rangeSearchMapDiscard(
          range, [&](const Span<float, Dim> &other, const int &iother) {
            if (iself != iother && dist(self, other) <= rVal) {
              edges.push_back(iself);
              edges.push_back(iother);
            }
          });
    }

    // Transpose is necessary here, clone to get ownership
    return Acts::detail::vectorToTensor2D(edges, 2).t().clone();
  }
};

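// Dimension-dispatching wrapper around BuildEdgesKDTree for embedding
// dimensions 1 to 12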
torch::Tensor Acts::detail::buildEdgesKDTree(torch::Tensor &embedFeatures,
                                             float rVal, int kVal,
                                             bool flipDirections) {
  auto tensor = Acts::template_switch<BuildEdgesKDTree, 1, 12>(
      embedFeatures.size(1), embedFeatures, rVal, kVal);

  return postprocessEdgeTensor(tensor, true, true, flipDirections);
}

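// Main entry point: use the CUDA FRNN implementation when ACTS is built with
// CUDA support and a GPU is available, otherwise fall back to the KD-tree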
torch::Tensor Acts::detail::buildEdges(torch::Tensor &embedFeatures, float rVal,
                                       int kVal, bool flipDirections) {
#ifndef ACTS_EXATRKX_CPUONLY
  if (torch::cuda::is_available()) {
    return detail::buildEdgesFRNN(embedFeatures, rVal, kVal, flipDirections);
  } else {
    return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
  }
#else
  return detail::buildEdgesKDTree(embedFeatures, rVal, kVal, flipDirections);
#endif
}
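
A minimal usage sketch (not part of the file above); the embedding shape and the radius/neighbour-count values are illustrative only:

torch::Tensor embedding = torch::rand({1000, 8});
torch::Tensor edges =
    Acts::detail::buildEdges(embedding, 0.2f, 100, /*flipDirections=*/false);
// edges has shape [2, nEdges] and dtype int64 (pairs of spacepoint indices)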