Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Clusterization.ipp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file Clusterization.ipp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2022 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <array>
10 #include <vector>
11 
12 #include <boost/pending/disjoint_sets.hpp>
13 
14 namespace Acts::Ccl::internal {
15 
16 // Machinery for validating generic Cell/Cluster types at compile-time
17 
// Compile-time trait: does cell type T provide the free functions needed to
// clusterize on a grid of the given dimension?
//
// Primary template. Chosen whenever neither specialization below is viable
// (a required function is missing, or the dimension is neither 1 nor 2), in
// which case the trait evaluates to false.
template <typename, size_t, typename T = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

// 2-D grid: the expressions getCellRow(cell), getCellColumn(cell) and
// getCellLabel(cell&) must all be well-formed.
template <typename T>
struct cellTypeHasRequiredFunctions<
    T, 2,
    std::void_t<decltype(getCellRow(std::declval<T>())),
                decltype(getCellColumn(std::declval<T>())),
                decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

// 1-D grid: only getCellColumn(cell) and getCellLabel(cell&) are required.
template <typename T>
struct cellTypeHasRequiredFunctions<
    T, 1,
    std::void_t<decltype(getCellColumn(std::declval<T>())),
                decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};
35 
// Compile-time trait: can a cell of type U be added to a cluster of type T
// through the free function clusterAddCell(T, U)?
//
// Primary template. Chosen when the call expression is ill-formed, in which
// case the trait evaluates to false.
template <typename, typename, typename T = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

// Selected when clusterAddCell(cluster, cell) is a well-formed expression.
template <typename T, typename U>
struct clusterTypeHasRequiredFunctions<
    T, U,
    std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
    : std::true_type {};
44 
// Rejects, at compile time, any grid dimension other than 1 or 2.
template <size_t GridDim>
constexpr void staticCheckGridDim() {
  constexpr bool isOneDim = (GridDim == 1);
  constexpr bool isTwoDim = (GridDim == 2);
  static_assert(
      isOneDim || isTwoDim,
      "mergeClusters is only defined for grid dimensions of 1 or 2. ");
}
51 
52 template <typename T, size_t GridDim>
53 constexpr void staticCheckCellType() {
54  constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
55  static_assert(hasFns,
56  "Cell type should have the following functions: "
57  "'int getCellRow(const Cell&)', "
58  "'int getCellColumn(const Cell&)', "
59  "'Label& getCellLabel(Cell&)'");
60 }
61 
62 template <typename T, typename U>
63 constexpr void staticCheckClusterType() {
64  constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
65  static_assert(hasFns,
66  "Cluster type should have the following function: "
67  "'void clusterAddCell(Cluster&, const Cell&)'");
68 }
69 
// Primary template of the cell comparator. Usable comparators are expected
// to be supplied as specializations for the supported grid dimensions; this
// generic version only exists to produce a clear diagnostic.
//
// The assertion must state the *supported* condition so that it fails when
// this template is instantiated with any other dimension. (The previous
// condition `GridDim != 1 && GridDim != 2` was true exactly for the
// unsupported dimensions, so the message was never emitted for them.)
template <typename Cell, size_t GridDim>
struct Compare {
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
75 
76 // Comparator function object for cells, column-wise ordering
77 // Specialization for 2-D grid
78 template <typename Cell>
79 struct Compare<Cell, 2> {
80  bool operator()(const Cell& c0, const Cell& c1) const {
81  int row0 = getCellRow(c0);
82  int row1 = getCellRow(c1);
83  int col0 = getCellColumn(c0);
84  int col1 = getCellColumn(c1);
85  return (col0 == col1) ? row0 < row1 : col0 < col1;
86  }
87 };
88 
89 // Specialization for 1-D grids
90 template <typename Cell>
91 struct Compare<Cell, 1> {
92  bool operator()(const Cell& c0, const Cell& c1) const {
93  int col0 = getCellColumn(c0);
94  int col1 = getCellColumn(c1);
95  return col0 < col1;
96  }
97 };
98 
99 // Simple wrapper around boost::disjoint_sets. In theory, could use
100 // boost::vector_property_map and use boost::disjoint_sets without
101 // wrapping, but it's way slower
103  public:
104  explicit DisjointSets(size_t initial_size = 128)
105  : m_size(initial_size),
106  m_rank(m_size),
107  m_parent(m_size),
108  m_ds(&m_rank[0], &m_parent[0]) {}
109 
111  // Empirically, m_size = 128 seems to be good default. If we
112  // exceed this, take a performance hit and do the right thing.
113  while (m_globalId >= m_size) {
114  m_size *= 2;
115  m_rank.resize(m_size);
116  m_parent.resize(m_size);
117  m_ds = boost::disjoint_sets<size_t*, size_t*>(&m_rank[0], &m_parent[0]);
118  }
119  m_ds.make_set(m_globalId);
120  return static_cast<Label>(m_globalId++);
121  }
122 
123  void unionSet(size_t x, size_t y) { m_ds.union_set(x, y); }
124  Label findSet(size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
125 
126  private:
127  size_t m_globalId = 1;
128  size_t m_size;
129  std::vector<size_t> m_rank;
130  std::vector<size_t> m_parent;
131  boost::disjoint_sets<size_t*, size_t*> m_ds;
132 };
133 
134 template <size_t BufSize>
136  size_t nconn{0};
137  std::array<Label, BufSize> buf;
138  ConnectionsBase() { std::fill(buf.begin(), buf.end(), NO_LABEL); }
139 };
140 
141 template <size_t GridDim>
142 class Connections {};
143 
144 // On 1-D grid, cells have 1 backward neighbor
145 template <>
146 struct Connections<1> : public ConnectionsBase<1> {
148 };
149 
150 // On a 2-D grid, cells have 4 backward neighbors
151 template <>
152 struct Connections<2> : public ConnectionsBase<4> {
154 };
155 
156 // Cell collection logic
157 template <typename Cell, typename Connect, size_t GridDim>
158 Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
159  std::vector<Cell>& set, Connect connect) {
161  typename std::vector<Cell>::iterator it_2{it};
162 
163  while (it_2 != set.begin()) {
164  it_2 = std::prev(it_2);
165 
166  ConnectResult cr = connect(*it, *it_2);
167  if (cr == ConnectResult::eNoConnStop) {
168  break;
169  }
170  if (cr == ConnectResult::eNoConn) {
171  continue;
172  }
173  if (cr == ConnectResult::eConn) {
174  seen.buf[seen.nconn] = getCellLabel(*it_2);
175  seen.nconn += 1;
176  if (seen.nconn == seen.buf.size()) {
177  break;
178  }
179  }
180  }
181  return seen;
182 }
183 
// Groups consecutive cells that carry the same label into clusters.
//
// A new cluster is started every time the label changes from one cell to
// the next, so cells of one cluster are expected to be adjacent in `cells`.
// Returns an empty collection for empty input.
template <typename CellCollection, typename ClusterCollection>
ClusterCollection mergeClustersImpl(CellCollection& cells) {
  using Cluster = typename ClusterCollection::value_type;

  if (cells.empty()) {
    return {};
  }

  ClusterCollection clusters;
  Cluster current;
  int activeLabel = getCellLabel(cells.front());

  for (auto& cell : cells) {
    if (getCellLabel(cell) != activeLabel) {
      // Label changed: flush the finished cluster and start a fresh one
      clusters.push_back(std::move(current));
      current = Cluster();
      activeLabel = getCellLabel(cell);
    }
    clusterAddCell(current, cell);
  }

  // The cluster still being filled when the loop ends belongs in the output
  clusters.push_back(std::move(current));

  return clusters;
}
210 
211 } // namespace Acts::Ccl::internal
212 
213 namespace Acts::Ccl {
214 
215 template <typename Cell>
216 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
217  const Cell& iter) const {
218  int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
219  int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
220  // Iteration is column-wise, so if too far in column, can
221  // safely stop
222  if (deltaCol > 1) {
223  return ConnectResult::eNoConnStop;
224  }
225  // For same reason, if too far in row we know the pixel is not
226  // connected, but need to keep iterating
227  if (deltaRow > 1) {
228  return ConnectResult::eNoConn;
229  }
230  // Decide whether or not cluster is connected based on 4- or
231  // 8-connectivity
232  if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
233  return ConnectResult::eConn;
234  }
235  return ConnectResult::eNoConn;
236 }
237 
238 template <typename Cell>
240  const Cell& iter) const {
241  int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
242  return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
243 }
244 
245 template <size_t GridDim>
248  // Sanity check: first element should always have
249  // label if nconn > 0
250  if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
251  throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
252  }
253  for (size_t i = 1; i < seen.nconn; i++) {
254  // Sanity check: since connection lookup is always backward
255  // while iteration is forward, all connected cells found here
256  // should have a label
257  if (seen.buf[i] == NO_LABEL) {
258  throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
259  }
260  // Only record equivalence if needed
261  if (seen.buf[0] != seen.buf[i]) {
262  ds.unionSet(seen.buf[0], seen.buf[i]);
263  }
264  }
265 }
266 
267 template <typename CellCollection, size_t GridDim, typename Connect>
268 void labelClusters(CellCollection& cells, Connect connect) {
269  using Cell = typename CellCollection::value_type;
270  internal::staticCheckCellType<Cell, GridDim>();
271 
273 
274  // Sort cells by position to enable in-order scan
275  std::sort(cells.begin(), cells.end(), internal::Compare<Cell, GridDim>());
276 
277  // First pass: Allocate labels and record equivalences
278  for (auto it = cells.begin(); it != cells.end(); ++it) {
279  const internal::Connections<GridDim> seen =
280  internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
281  if (seen.nconn == 0) {
282  // Allocate new label
283  getCellLabel(*it) = ds.makeSet();
284  } else {
285  recordEquivalences(seen, ds);
286  // Set label for current cell
287  getCellLabel(*it) = seen.buf[0];
288  }
289  }
290 
291  // Second pass: Merge labels based on recorded equivalences
292  for (auto& cell : cells) {
293  Label& lbl = getCellLabel(cell);
294  lbl = ds.findSet(lbl);
295  }
296 }
297 
298 template <typename CellCollection, typename ClusterCollection,
299  size_t GridDim = 2>
300 ClusterCollection mergeClusters(CellCollection& cells) {
301  using Cell = typename CellCollection::value_type;
302  using Cluster = typename ClusterCollection::value_type;
303  internal::staticCheckGridDim<GridDim>();
304  internal::staticCheckCellType<Cell, GridDim>();
305  internal::staticCheckClusterType<Cluster&, const Cell&>();
306 
307  if constexpr (GridDim > 1) {
308  // Sort the cells by their cluster label, only needed if more than
309  // one spatial dimension
310  std::sort(cells.begin(), cells.end(), [](Cell& lhs, Cell& rhs) {
311  return getCellLabel(lhs) < getCellLabel(rhs);
312  });
313  }
314 
315  return internal::mergeClustersImpl<CellCollection, ClusterCollection>(cells);
316 }
317 
318 template <typename CellCollection, typename ClusterCollection, size_t GridDim,
319  typename Connect>
320 ClusterCollection createClusters(CellCollection& cells, Connect connect) {
321  using Cell = typename CellCollection::value_type;
322  using Cluster = typename ClusterCollection::value_type;
323  internal::staticCheckCellType<Cell, GridDim>();
324  internal::staticCheckClusterType<Cluster&, const Cell&>();
325  labelClusters<CellCollection, GridDim, Connect>(cells, connect);
326  return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
327 }
328 
329 } // namespace Acts::Ccl