Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
KDTree.hpp
Go to the documentation of this file, or view the newest version of KDTree.hpp in the sPHENIX GitHub repository.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
12 
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
20 
21 namespace Acts {
43 template <std::size_t Dims, typename Type, typename Scalar = double,
44  template <typename, std::size_t> typename Vector = std::array,
45  std::size_t LeafSize = 4>
46 class KDTree {
47  public:
49  using value_t = Type;
50 
53 
55  using coordinate_t = Vector<Scalar, Dims>;
56 
58  using pair_t = std::pair<coordinate_t, Type>;
59 
61  using vector_t = std::vector<pair_t>;
62 
64  using iterator_t = typename vector_t::iterator;
65 
66  using const_iterator_t = typename vector_t::const_iterator;
67 
68  // We do not need an empty constructor - this is never useful.
69  KDTree() = delete;
70 
78  KDTree(vector_t &&d) : m_elems(d) {
79  // To start out, we need to check whether we need to construct a leaf node
80  // or an internal node. We create a leaf only if we have at most as many
81  // elements as the number of elements that can fit into a leaf node.
82  // Hopefully most invocations of this constructor will have more than a few
83  // elements!
84  //
85  // One interesting thing to note is that all of the nodes in the k-d tree
86  // have a range in the element vector of the outermost node. They simply
87  // make in-place changes to this array, and they hold no memory of their
88  // own.
89  m_root = std::make_unique<KDTreeNode>(m_elems.begin(), m_elems.end(),
90  m_elems.size() > LeafSize
91  ? KDTreeNode::NodeType::Internal
92  : KDTreeNode::NodeType::Leaf,
93  0UL);
94  }
95 
106  std::vector<Type> rangeSearch(const range_t &r) const {
107  std::vector<Type> out;
108 
109  rangeSearch(r, out);
110 
111  return out;
112  }
113 
124  std::vector<pair_t> rangeSearchWithKey(const range_t &r) const {
125  std::vector<pair_t> out;
126 
127  rangeSearchWithKey(r, out);
128 
129  return out;
130  }
131 
139  void rangeSearch(const range_t &r, std::vector<Type> &v) const {
140  rangeSearchInserter(r, std::back_inserter(v));
141  }
142 
151  void rangeSearchWithKey(const range_t &r, std::vector<pair_t> &v) const {
152  rangeSearchInserterWithKey(r, std::back_inserter(v));
153  }
154 
165  template <typename OutputIt>
166  void rangeSearchInserter(const range_t &r, OutputIt i) const {
168  r, [i](const coordinate_t &, const Type &v) mutable { i = v; });
169  }
170 
181  template <typename OutputIt>
182  void rangeSearchInserterWithKey(const range_t &r, OutputIt i) const {
183  rangeSearchMapDiscard(r, [i](const coordinate_t &c, const Type &v) mutable {
184  i = {c, v};
185  });
186  }
187 
206  template <typename Result>
207  std::vector<Result> rangeSearchMap(
208  const range_t &r,
209  std::function<Result(const coordinate_t &, const Type &)> f) const {
210  std::vector<Result> out;
211 
212  rangeSearchMapInserter(r, f, std::back_inserter(out));
213 
214  return out;
215  }
216 
233  template <typename Result, typename OutputIt>
235  const range_t &r,
236  std::function<Result(const coordinate_t &, const Type &)> f,
237  OutputIt i) const {
238  rangeSearchMapDiscard(r, [i, f](const coordinate_t &c,
239  const Type &v) mutable { i = f(c, v); });
240  }
241 
253  template <typename Callable>
254  void rangeSearchMapDiscard(const range_t &r, Callable &&f) const {
255  m_root->rangeSearchMapDiscard(r, std::forward<Callable>(f));
256  }
257 
263  std::size_t size(void) const { return m_root->size(); }
264 
265  const_iterator_t begin(void) const { return m_elems.begin(); }
266 
267  const_iterator_t end(void) const { return m_elems.end(); }
268 
269  private:
271  // I'm not super happy with this bit of code, but since 1D ranges are
272  // semi-open, we can't simply incorporate values by setting the maximum to
273  // them. Instead, what we need to do is get the next representable value.
274  // For integer values, this means adding one. For floating point types, we
275  // rely on the nextafter method to get the smallest possible value that is
276  // larger than the one we requested.
277  if constexpr (std::is_integral_v<Scalar>) {
278  return v + 1;
279  } else if constexpr (std::is_floating_point_v<Scalar>) {
280  return std::nextafter(v, std::numeric_limits<Scalar>::max());
281  }
282  }
283 
285  // Firstly, we find the minimum and maximum value in each dimension to
286  // construct a bounding box around this node's values.
287  std::array<Scalar, Dims> min_v{}, max_v{};
288 
289  for (std::size_t i = 0; i < Dims; ++i) {
290  min_v[i] = std::numeric_limits<Scalar>::max();
291  max_v[i] = std::numeric_limits<Scalar>::lowest();
292  }
293 
294  for (iterator_t i = b; i != e; ++i) {
295  for (std::size_t j = 0; j < Dims; ++j) {
296  min_v[j] = std::min(min_v[j], i->first[j]);
297  max_v[j] = std::max(max_v[j], i->first[j]);
298  }
299  }
300 
301  // Then, we construct a k-dimensional range from the given minima and
302  // maxima, which again is just a bounding box.
303  range_t r;
304 
305  for (std::size_t j = 0; j < Dims; ++j) {
306  r[j] = {min_v[j], nextRepresentable(max_v[j])};
307  }
308 
309  return r;
310  }
311 
318  class KDTreeNode {
319  public:
321  enum class NodeType { Internal, Leaf };
322 
329  KDTreeNode(iterator_t _b, iterator_t _e, NodeType _t, std::size_t _d)
330  : m_type(_t),
331  m_begin_it(_b),
332  m_end_it(_e),
334  if (m_type == NodeType::Internal) {
335  // This constant determines the maximum number of elements where we
336  // still
337  // calculate the exact median of the values for the purposes of
338  // splitting. In general, the closer the pivot value is to the true
339  // median, the more balanced the tree will be. However, calculating the
340  // median exactly is an O(n log n) operation, while approximating it is
341  // an O(1) time.
342  constexpr std::size_t max_exact_median = 128;
343 
344  iterator_t pivot;
345 
346  // Next, we need to determine the pivot point of this node, that is to
347  // say the point in the selected pivot dimension along which point we
348  // will split the range. To do this, we check how large the set of
349  // elements is. If it is sufficiently small, we use the median.
350  // Otherwise we use the mean.
351  if (size() > max_exact_median) {
352  // In this case, we have a lot of elements, and sorting the range to
353  // find the true median might be too expensive. Therefore, we will
354  // just use the middle value between the minimum and maximum. This is
355  // not nearly as accurate as using the median, but it's a nice cheat.
356  Scalar mid = static_cast<Scalar>(0.5) *
357  (m_range[_d].max() + m_range[_d].min());
358 
359  pivot = std::partition(m_begin_it, m_end_it, [=](const pair_t &i) {
360  return i.first[_d] < mid;
361  });
362  } else {
363  // If the number of elements is fairly small, we will just calculate
364  // the median exactly. We do this by finding the values in the
365  // dimension, sorting it, and then taking the middle one.
367  [_d](const typename iterator_t::value_type &a,
368  const typename iterator_t::value_type &b) {
369  return a.first[_d] < b.first[_d];
370  });
371 
372  pivot = m_begin_it + (std::distance(m_begin_it, m_end_it) / 2);
373  }
374 
375  // This should never really happen, but in very select cases where there
376  // are a lot of equal values in the range, the pivot can end up all the
377  // way at the end of the array and we end up in an infinite loop. We
378  // check for pivot points which would not split the range, and fix them
379  // if they occur.
380  if (pivot == m_begin_it || pivot == std::prev(m_end_it)) {
381  pivot = std::next(m_begin_it, LeafSize);
382  }
383 
384  // Calculate the number of elements on the left-hand side, as well as
385  // the right-hand side. We do this by calculating the difference from
386  // the begin and end of the array to the pivot point.
387  std::size_t lhs_size = std::distance(m_begin_it, pivot);
388  std::size_t rhs_size = std::distance(pivot, m_end_it);
389 
390  // Next, we check whether the left-hand node should be another internal
391  // node or a leaf node, and we construct the node recursively.
392  m_lhs = std::make_unique<KDTreeNode>(
393  m_begin_it, pivot,
394  lhs_size > LeafSize ? NodeType::Internal : NodeType::Leaf,
395  (_d + 1) % Dims);
396 
397  // Same on the right hand side.
398  m_rhs = std::make_unique<KDTreeNode>(
399  pivot, m_end_it,
400  rhs_size > LeafSize ? NodeType::Internal : NodeType::Leaf,
401  (_d + 1) % Dims);
402  }
403  }
404 
414  template <typename Callable>
415  void rangeSearchMapDiscard(const range_t &r, Callable &&f) const {
416  // Determine whether the range completely covers the bounding box of
417  // this leaf node. If it is, we can copy all values without having to
418  // check for them being inside the range again.
419  bool contained = r >= m_range;
420 
421  if (m_type == NodeType::Internal) {
422  // Firstly, we can check if the range completely contains the bounding
423  // box of this node. If that is the case, we know for certain that any
424  // value contained below this node should end up in the output, and we
425  // can stop recursively looking for them.
426  if (contained) {
427  // We can also pre-allocate space for the number of elements, since we
428  // are inserting all of them anyway.
429  for (iterator_t i = m_begin_it; i != m_end_it; ++i) {
430  f(i->first, i->second);
431  }
432 
433  return;
434  }
435 
436  assert(m_lhs && m_rhs && "Did not find lhs and rhs");
437 
438  // If we have a left-hand node (which we should!), then we check if
439  // there is any overlap between the target range and the bounding box of
440  // the left-hand node. If there is, we recursively search in that node.
441  if (m_lhs->range() && r) {
442  m_lhs->rangeSearchMapDiscard(r, std::forward<Callable>(f));
443  }
444 
445  // Then, we perform exactly the same procedure for the right hand side.
446  if (m_rhs->range() && r) {
447  m_rhs->rangeSearchMapDiscard(r, std::forward<Callable>(f));
448  }
449  } else {
450  // Iterate over all the elements in this leaf node. This should be a
451  // relatively small number (the LeafSize template parameter).
452  for (iterator_t i = m_begin_it; i != m_end_it; ++i) {
453  // We need to check whether the element is actually inside the range.
454  // In case this node's bounding box is fully contained within the
455  // range, we don't actually need to check this.
456  if (contained || r.contains(i->first)) {
457  f(i->first, i->second);
458  }
459  }
460  }
461  }
462 
470  std::size_t size() const { return std::distance(m_begin_it, m_end_it); }
471 
477  const range_t &range() const { return m_range; }
478 
479  protected:
481 
485 
489 
491  std::unique_ptr<KDTreeNode> m_lhs;
492  std::unique_ptr<KDTreeNode> m_rhs;
493  };
494 
498 
500  std::unique_ptr<KDTreeNode> m_root;
501 };
502 } // namespace Acts