Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExaTrkXEdgeBuildingTest.cpp
Go to the documentation of this file, or view the newest version of ExaTrkXEdgeBuildingTest.cpp in the sPHENIX GitHub repository.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2022 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <boost/test/unit_test.hpp>
10 
14 
15 #include <cassert>
16 #include <iostream>
17 
18 #include <Eigen/Core>
19 #include <torch/torch.h>
20 
22 
23 #define PRINT 0
24 
25 float distance(const at::Tensor &a, const at::Tensor &b) {
26  assert(a.sizes() == b.sizes());
27  assert(a.sizes().size() == 1);
28 
29  return std::sqrt(((a - b) * (a - b)).sum().item().to<float>());
30 }
31 
#if PRINT
// Debug printer: writes a Cantor pair as its decoded components "(first,second)".
std::ostream &operator<<(std::ostream &os, CantorPair p) {
  const auto [first, second] = p.inverse();
  return os << "(" << first << "," << second << ")";
}
#endif
39 
40 template <typename edge_builder_t>
41 void test_random_graph(int emb_dim, int n_nodes, float r, int knn,
42  const edge_builder_t &edgeBuilder) {
43  // Create a random point cloud
44  auto random_features = at::randn({n_nodes, emb_dim});
45 
46  // Generate the truth via brute-force
47  Eigen::MatrixXf distance_matrix(n_nodes, n_nodes);
48 
49  std::vector<CantorPair> edges_ref_cantor;
50  std::vector<int> edge_counts(n_nodes, 0);
51 
52  for (int i = 0; i < n_nodes; ++i) {
53  for (int j = i; j < n_nodes; ++j) {
54  const auto d = distance(random_features[i], random_features[j]);
55  distance_matrix(i, j) = d;
56  distance_matrix(j, i) = d;
57 
58  if (d < r && i != j) {
59  edges_ref_cantor.emplace_back(i, j);
60  edge_counts[i]++;
61  }
62  }
63  }
64 
65  const auto max_edges =
66  *std::max_element(edge_counts.begin(), edge_counts.end());
67 
68  // If this is not the case, the test is ill-formed
69  // knn specifies how many edges can be found by the function at max. Thus, we
70  // should design the test in a way, that our brute-force test algorithm does
71  // not find more edges than the algorithm that we test against it can find
72  BOOST_REQUIRE(max_edges <= knn);
73 
74  // Run the edge building
75  auto edges_test = edgeBuilder(random_features, r, knn);
76 
77  // Map the edges to cantor pairs
78  std::vector<CantorPair> edges_test_cantor;
79 
80  for (int i = 0; i < edges_test.size(1); ++i) {
81  const auto a = edges_test[0][i].template item<int>();
82  const auto b = edges_test[1][i].template item<int>();
83  edges_test_cantor.push_back(a < b ? CantorPair(a, b) : CantorPair(b, a));
84  }
85 
86  std::sort(edges_ref_cantor.begin(), edges_ref_cantor.end());
87  std::sort(edges_test_cantor.begin(), edges_test_cantor.end());
88 
89 #if PRINT
90  std::cout << "test size " << edges_test_cantor.size() << std::endl;
91  std::cout << "ref size " << edges_ref_cantor.size() << std::endl;
92  std::cout << "test: ";
93  std::copy(
94  edges_test_cantor.begin(),
95  edges_test_cantor.begin() + std::min(edges_test_cantor.size(), 10ul),
96  std::ostream_iterator<CantorPair>(std::cout, " "));
97  std::cout << std::endl;
98  std::cout << "ref: ";
99  std::copy(edges_ref_cantor.begin(),
100  edges_ref_cantor.begin() + std::min(edges_ref_cantor.size(), 10ul),
101  std::ostream_iterator<CantorPair>(std::cout, " "));
102  std::cout << std::endl;
103 #endif
104 
105  // Check
106  BOOST_CHECK(edges_ref_cantor.size() == edges_test_cantor.size());
107  BOOST_CHECK(std::equal(edges_test_cantor.begin(), edges_test_cantor.end(),
108  edges_ref_cantor.begin()));
109 }
110 
111 BOOST_AUTO_TEST_CASE(test_cantor_pair_functions) {
112  int a = 345;
113  int b = 23;
114  // Use non-sorted cantor pair to make this work
115  const auto [aa, bb] = CantorPair(a, b, false).inverse();
116  BOOST_CHECK(a == aa);
117  BOOST_CHECK(b == bb);
118 }
119 
120 BOOST_AUTO_TEST_CASE(test_cantor_pair_sorted) {
121  int a = 345;
122  int b = 23;
123  CantorPair c1(a, b);
124  CantorPair c2(b, a);
125  BOOST_CHECK(c1.value() == c2.value());
126 }
127 
// Shared parameters for the random-graph edge-building tests.
constexpr int emb_dim = 3;   // feature-space dimensionality
constexpr int n_nodes = 20;  // number of random points
constexpr float r = 1.5F;    // connection radius
constexpr int knn = 50;      // neighbour cap passed to the edge builders
constexpr int seed = 42;     // torch RNG seed for reproducibility
133 
134 BOOST_AUTO_TEST_CASE(test_random_graph_edge_building_cuda,
135  *boost::unit_test::precondition([](auto) {
136  return torch::cuda::is_available();
137  })) {
138  torch::manual_seed(seed);
139 
140  auto cudaEdgeBuilder = [](auto &features, auto radius, auto k) {
141  auto features_cuda = features.to(torch::kCUDA);
142  return Acts::detail::buildEdgesFRNN(features_cuda, radius, k);
143  };
144 
145  test_random_graph(emb_dim, n_nodes, r, knn, cudaEdgeBuilder);
146 }
147 
148 BOOST_AUTO_TEST_CASE(test_random_graph_edge_building_kdtree) {
149  torch::manual_seed(seed);
150 
151  auto cpuEdgeBuilder = [](auto &features, auto radius, auto k) {
152  auto features_cpu = features.to(torch::kCPU);
153  return Acts::detail::buildEdgesKDTree(features_cpu, radius, k);
154  };
155 
156  test_random_graph(emb_dim, n_nodes, r, knn, cpuEdgeBuilder);
157 }
158 
159 BOOST_AUTO_TEST_CASE(test_self_loop_removal) {
160  // clang-format off
161  std::vector<int64_t> edges = {
162  1,1,
163  2,3,
164  2,2,
165  5,4,
166  };
167  // clang-format on
168 
169  auto opts = torch::TensorOptions().dtype(torch::kInt64);
170  const auto edgeTensor =
171  torch::from_blob(edges.data(), {static_cast<long>(edges.size() / 2), 2},
172  opts)
173  .transpose(0, 1);
174 
175  const auto withoutSelfLoops =
176  Acts::detail::postprocessEdgeTensor(edgeTensor, true, false, false)
177  .transpose(1, 0)
178  .flatten();
179 
180  const std::vector<int64_t> postEdges(
181  withoutSelfLoops.data_ptr<int64_t>(),
182  withoutSelfLoops.data_ptr<int64_t>() + withoutSelfLoops.numel());
183 
184  // clang-format off
185  const std::vector<int64_t> ref = {
186  2,3,
187  5,4,
188  };
189  // clang-format on
190 
191  BOOST_CHECK(ref == postEdges);
192 }
193 
194 BOOST_AUTO_TEST_CASE(test_duplicate_removal) {
195  // clang-format off
196  std::vector<int64_t> edges = {
197  1,2,
198  2,1, // duplicate, flipped
199  3,2,
200  3,2, // duplicate, not flipped
201  7,6, // should be flipped
202  };
203  // clang-format on
204 
205  auto opts = torch::TensorOptions().dtype(torch::kInt64);
206  const auto edgeTensor =
207  torch::from_blob(edges.data(), {static_cast<long>(edges.size() / 2), 2},
208  opts)
209  .transpose(0, 1);
210 
211  const auto withoutDups =
212  Acts::detail::postprocessEdgeTensor(edgeTensor, false, true, false)
213  .transpose(1, 0)
214  .flatten();
215 
216  const std::vector<int64_t> postEdges(
217  withoutDups.data_ptr<int64_t>(),
218  withoutDups.data_ptr<int64_t>() + withoutDups.numel());
219 
220  // clang-format off
221  const std::vector<int64_t> ref = {
222  1,2,
223  2,3,
224  6,7,
225  };
226  // clang-format on
227 
228  BOOST_CHECK(ref == postEdges);
229 }
230 
231 BOOST_AUTO_TEST_CASE(test_random_flip) {
232  torch::manual_seed(seed);
233 
234  // clang-format off
235  std::vector<int64_t> edges = {
236  1,2,
237  2,3,
238  3,4,
239  4,5,
240  };
241  // clang-format on
242 
243  auto opts = torch::TensorOptions().dtype(torch::kInt64);
244  const auto edgeTensor =
245  torch::from_blob(edges.data(), {static_cast<long>(edges.size() / 2), 2},
246  opts)
247  .transpose(0, 1);
248 
249  const auto flipped =
250  Acts::detail::postprocessEdgeTensor(edgeTensor, false, false, true)
251  .transpose(0, 1)
252  .flatten();
253 
254  const std::vector<int64_t> postEdges(
255  flipped.data_ptr<int64_t>(),
256  flipped.data_ptr<int64_t>() + flipped.numel());
257 
258  BOOST_CHECK(postEdges.size() == edges.size());
259  for (auto preIt = edges.begin(); preIt != edges.end(); preIt += 2) {
260  int found = 0;
261 
262  for (auto postIt = postEdges.begin(); postIt != postEdges.end();
263  postIt += 2) {
264  bool noflp = (*preIt == *postIt) and *(preIt + 1) == *(postIt + 1);
265  bool flp = *preIt == *(postIt + 1) and *(preIt + 1) == *(postIt);
266 
267  found += (flp or noflp);
268  }
269 
270  BOOST_CHECK(found == 1);
271  }
272 }