Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SymmetricKlDistanceMatrix.hpp
Go to the documentation of this file, or view the newest version of SymmetricKlDistanceMatrix.hpp in the sPHENIX GitHub repository.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
16 
17 namespace Acts::detail {
18 
21 template <typename component_t, typename component_projector_t>
22 auto computeSymmetricKlDivergence(const component_t &a, const component_t &b,
23  const component_projector_t &proj) {
24  using namespace Acts;
25  const auto parsA = proj(a).boundPars[eBoundQOverP];
26  const auto parsB = proj(b).boundPars[eBoundQOverP];
27  const auto covA = proj(a).boundCov(eBoundQOverP, eBoundQOverP);
28  const auto covB = proj(b).boundCov(eBoundQOverP, eBoundQOverP);
29 
30  assert(covA != 0.0);
31  assert(std::isfinite(covA));
32  assert(covB != 0.0);
33  assert(std::isfinite(covB));
34 
35  const auto kl = covA * (1 / covB) + covB * (1 / covA) +
36  (parsA - parsB) * (1 / covA + 1 / covB) * (parsA - parsB);
37 
38  assert(kl >= 0.0 && "kl-divergence must be non-negative");
39 
40  return kl;
41 }
42 
/// Merges two components into a single one.
///
/// The result is a copy of @p a whose projected bound parameters and
/// covariance are replaced by the mean and covariance that
/// gaussianMixtureMeanCov returns for the weighted pair {a, b}, and whose
/// projected weight is the sum of both input weights.
///
/// @param a          first component; also provides every non-projected field
///                   of the result
/// @param b          second component
/// @param proj       projector returning a reference exposing `weight`,
///                   `boundPars` and `boundCov`
/// @param angle_desc forwarded unchanged to gaussianMixtureMeanCov
///                   (presumably describes cyclic angle coordinates — confirm)
/// @return the merged component
template <typename component_t, typename component_projector_t,
          typename angle_desc_t>
auto mergeComponents(const component_t &a, const component_t &b,
                     const component_projector_t &proj,
                     const angle_desc_t &angle_desc) {
  // Zero weights are accepted; only negative weights are rejected, so the
  // message must say "negative" (the previous "non-positive" was misleading).
  assert(proj(a).weight >= 0.0 && proj(b).weight >= 0.0 &&
         "negative weight");

  // View both components through reference wrappers so the mixture helper
  // reads them without copying.
  std::array range = {std::ref(proj(a)), std::ref(proj(b))};
  const auto refProj = [](auto &c) {
    return std::tie(c.get().weight, c.get().boundPars, c.get().boundCov);
  };

  auto [mergedPars, mergedCov] =
      gaussianMixtureMeanCov(range, refProj, angle_desc);

  component_t ret = a;
  proj(ret).boundPars = mergedPars;
  proj(ret).boundCov = mergedCov;
  proj(ret).weight = proj(a).weight + proj(b).weight;

  return ret;
}
66 
69  using Array = Eigen::Array<Acts::ActsScalar, Eigen::Dynamic, 1>;
70  using Mask = Eigen::Array<bool, Eigen::Dynamic, 1>;
71 
74  std::vector<std::pair<std::size_t, std::size_t>> m_mapToPair;
75  std::size_t m_numberComponents;
76 
77  template <typename array_t, typename setter_t>
78  void setAssociated(std::size_t n, array_t &array, setter_t &&setter) {
79  const auto indexConst = (n - 1) * n / 2;
80 
81  // Rows
82  for (auto i = 0ul; i < n; ++i) {
83  array[indexConst + i] = setter(n, i);
84  }
85 
86  // Columns
87  for (auto i = n + 1; i < m_numberComponents; ++i) {
88  array[(i - 1) * i / 2 + n] = setter(n, i);
89  }
90  }
91 
92  public:
93  template <typename component_t, typename projector_t>
94  SymmetricKLDistanceMatrix(const std::vector<component_t> &cmps,
95  const projector_t &proj)
96  : m_distances(Array::Zero(cmps.size() * (cmps.size() - 1) / 2)),
97  m_mask(Mask::Ones(cmps.size() * (cmps.size() - 1) / 2)),
99  m_numberComponents(cmps.size()) {
100  for (auto i = 1ul; i < m_numberComponents; ++i) {
101  const auto indexConst = (i - 1) * i / 2;
102  for (auto j = 0ul; j < i; ++j) {
103  m_mapToPair.at(indexConst + j) = {i, j};
104  m_distances[indexConst + j] =
105  computeSymmetricKlDivergence(cmps[i], cmps[j], proj);
106  }
107  }
108  }
109 
110  auto at(std::size_t i, std::size_t j) const {
111  return m_distances[i * (i - 1) / 2 + j];
112  }
113 
114  template <typename component_t, typename projector_t>
115  void recomputeAssociatedDistances(std::size_t n,
116  const std::vector<component_t> &cmps,
117  const projector_t &proj) {
118  assert(cmps.size() == m_numberComponents && "size mismatch");
119 
120  setAssociated(n, m_distances, [&](std::size_t i, std::size_t j) {
121  return computeSymmetricKlDivergence(cmps[i], cmps[j], proj);
122  });
123  }
124 
125  void maskAssociatedDistances(std::size_t n) {
126  setAssociated(n, m_mask, [&](std::size_t, std::size_t) { return false; });
127  }
128 
129  auto minDistancePair() const {
130  auto min = std::numeric_limits<Acts::ActsScalar>::max();
131  std::size_t idx = 0;
132 
133  for (auto i = 0l; i < m_distances.size(); ++i) {
134  if (auto new_min = std::min(min, m_distances[i]);
135  m_mask[i] && new_min < min) {
136  min = new_min;
137  idx = i;
138  }
139  }
140 
141  return m_mapToPair.at(idx);
142  }
143 
144  friend std::ostream &operator<<(std::ostream &os,
145  const SymmetricKLDistanceMatrix &m) {
146  const auto prev_precision = os.precision();
147  const int width = 8;
148  const int prec = 2;
149 
150  os << "\n";
151  os << std::string(width, ' ') << " | ";
152  for (auto j = 0ul; j < m.m_numberComponents - 1; ++j) {
153  os << std::setw(width) << j << " ";
154  }
155  os << "\n";
156  os << std::string((width + 3) + (width + 2) * (m.m_numberComponents - 1),
157  '-');
158  os << "\n";
159 
160  for (auto i = 1ul; i < m.m_numberComponents; ++i) {
161  const auto indexConst = (i - 1) * i / 2;
162  os << std::setw(width) << i << " | ";
163  for (auto j = 0ul; j < i; ++j) {
164  os << std::setw(width) << std::setprecision(prec)
165  << m.m_distances[indexConst + j] << " ";
166  }
167  os << "\n";
168  }
169  os << std::setprecision(prev_precision);
170  return os;
171  }
172 };
173 
174 } // namespace Acts::detail