Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
GsfUtils.hpp
Go to the documentation of this file, or view the newest version of GsfUtils.hpp on the sPHENIX GitHub.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
17 
#include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <cmath>
21 #include <cstddef>
22 #include <iomanip>
23 #include <map>
24 #include <numeric>
25 #include <ostream>
26 #include <tuple>
27 #include <vector>
28 
29 namespace Acts {
30 
/// Default tolerance on |sum of weights - 1| used when checking whether the
/// weights of a set of components are normalized.
/// `inline constexpr` (C++17) gives one definition shared by every translation
/// unit that includes this header, instead of one internal-linkage copy each.
inline constexpr double s_normalizationTolerance = 1.e-4;
33 
34 namespace detail {
35 
36 template <typename component_range_t, typename projector_t>
37 bool weightsAreNormalized(const component_range_t &cmps,
38  const projector_t &proj,
39  double tol = s_normalizationTolerance) {
40  double sumOfWeights = 0.0;
41 
42  for (auto it = cmps.begin(); it != cmps.end(); ++it) {
43  sumOfWeights += proj(*it);
44  }
45 
46  return std::abs(sumOfWeights - 1.0) < tol;
47 }
48 
/// Rescale the weights of a range of components in place so they sum to one.
///
/// @param cmps Range of components whose weights are modified.
/// @param proj Projector returning a mutable reference to a component's
///             weight.
///
/// Preconditions, checked with assert only (i.e. not when NDEBUG is defined):
/// every weight is finite and the sum of all weights is strictly positive.
template <typename component_range_t, typename projector_t>
void normalizeWeights(component_range_t &cmps, const projector_t &proj) {
  // Components may be proxy objects with reference semantics that are
  // returned by value from the iterator; a forwarding reference binds both
  // those and real references (a plain `auto &` would fail to bind).
  double total = 0.0;
  for (auto &&cmp : cmps) {
    const double weight = proj(cmp);
    assert(std::isfinite(weight) && "weight not finite in normalization");
    total += weight;
  }

  assert(total > 0 && "sum of weights is not > 0");

  for (auto &&cmp : cmps) {
    proj(cmp) /= total;
  }
}
68 
69 // A class that prints information about the state on construction and
70 // destruction, it also contains some assertions in the constructor and
71 // destructor. Apart from VERBOSE logging and these assertions (which compile
72 // away under NDEBUG) it has no effect: it holds only const references plus a
// copy of the initial momentum, and never modifies the observed objects.
// NOTE(review): source line 74 (presumably the class head
// `class ScopedGsfInfoPrinterAndChecker {`) is missing from this listing.
73 template <typename propagator_state_t, typename stepper_t, typename navigator_t>
// Observed propagation objects (const references, never modified).
75  const propagator_state_t &m_state;
76  const stepper_t &m_stepper;
77  const navigator_t &m_navigator;
// Absolute momentum at construction; used to log the momentum change at the
// end of the step.
78  double m_p_initial;
// Logger returned by logger() (presumably what the ACTS_VERBOSE macro uses).
79  const Logger &m_logger;
80 
81  const Logger &logger() const { return m_logger; }
82 
// Log, for every component (numbered from 0), its position, direction,
// weight, status, q/p and the determinant of its covariance (VERBOSE level).
83  void print_component_stats() const {
84  std::size_t i = 0;
85  for (auto cmp : m_stepper.constComponentIterable(m_state.stepping)) {
// The 3 free parameters starting at index `idx`, transposed to a row
// (presumably so they print on a single line).
86  auto getVector = [&](auto idx) {
87  return cmp.pars().template segment<3>(idx).transpose();
88  };
89  ACTS_VERBOSE(" #" << i++ << " pos: " << getVector(eFreePos0) << ", dir: "
90  << getVector(eFreeDir0) << ", weight: " << cmp.weight()
91  << ", status: " << cmp.status()
92  << ", qop: " << cmp.pars()[eFreeQOverP]
93  << ", det(cov): " << cmp.cov().determinant());
94  }
95  }
96 
// Assert that at least one component exists and that all component weights
// are finite and sum to 1 within s_normalizationTolerance (the default
// tolerance of weightsAreNormalized). `onStart` only selects which assertion
// message is used. The results are [[maybe_unused]] because the asserts
// vanish under NDEBUG.
97  void checks(bool onStart) const {
98  const auto cmps = m_stepper.constComponentIterable(m_state.stepping);
99  [[maybe_unused]] const bool allFinite =
100  std::all_of(cmps.begin(), cmps.end(),
101  [](auto cmp) { return std::isfinite(cmp.weight()); });
102  [[maybe_unused]] const bool allNormalized = detail::weightsAreNormalized(
103  cmps, [](const auto &cmp) { return cmp.weight(); });
104  [[maybe_unused]] const bool zeroComponents =
105  m_stepper.numberComponents(m_state.stepping) == 0;
106 
107  if (onStart) {
108  assert(not zeroComponents && "no cmps at the start");
109  assert(allFinite && "weights not finite at the start");
110  assert(allNormalized && "not normalized at the start");
111  } else {
112  assert(not zeroComponents && "no cmps at the end");
113  assert(allFinite && "weights not finite at the end");
114  assert(allNormalized && "not normalized at the end");
115  }
116  }
117 
118  public:
// Runs the start-of-step checks, then logs the step number, mean position,
// direction, momentum and charge, and the propagation direction.
// The momentum is stored in m_p_initial (before m_logger is initialized,
// matching the member declaration order).
119  ScopedGsfInfoPrinterAndChecker(const propagator_state_t &state,
120  const stepper_t &stepper,
121  const navigator_t &navigator,
122  const Logger &logger)
123  : m_state(state),
124  m_stepper(stepper),
125  m_navigator(navigator),
126  m_p_initial(stepper.absoluteMomentum(state.stepping)),
127  m_logger{logger} {
128  // Validate the incoming state first, then print a summary of it
129  checks(true);
130  ACTS_VERBOSE("Gsf step "
131  << state.stepping.steps << " at mean position "
132  << stepper.position(state.stepping).transpose()
133  << " with direction "
134  << stepper.direction(state.stepping).transpose()
135  << " and momentum " << stepper.absoluteMomentum(state.stepping)
136  << " and charge " << stepper.charge(state.stepping));
137  ACTS_VERBOSE("Propagation is in " << state.options.direction << " mode");
// NOTE(review): source line 138 is missing from this listing.
139  }
140 
// NOTE(review): source line 141 (presumably the destructor head) is missing
// from this listing. The body below logs the momentum change since
// construction if the navigator currently reports a surface, then runs the
// end-of-step checks.
142  if (m_navigator.currentSurface(m_state.navigation)) {
143  const auto p_final = m_stepper.absoluteMomentum(m_state.stepping);
144  ACTS_VERBOSE("Component status at end of step:");
// NOTE(review): source line 145 is missing from this listing.
146  ACTS_VERBOSE("Delta Momentum = " << std::setprecision(5)
147  << p_final - m_p_initial);
148  }
149  checks(false);
150  }
151 };
152 
// NOTE(review): source lines 153, 155 and 157 (presumably the declaration
// head `calculateDeterminant(` and the type parts of the covariance and
// projector parameters) are missing from this listing.
// Declaration only; the definition is not in this header (presumably in a
// source file). In this header it is called with raw data pointers of the
// calibrated measurement (taken from a MeasurementSizeMax-sized view) and of
// its covariance, the predicted covariance, the projector and the actual
// measurement dimension `calibratedSize`. The caller uses the result as
// det(R) when reweighting components.
154  const double *fullCalibrated, const double *fullCalibratedCovariance,
156  true>::Covariance predictedCovariance,
158  projector,
159  unsigned int calibratedSize);
160 
// NOTE(review): source lines 166 (the function's return type and name) and
// 193 (the template argument of calibratedCovariance<...>) are missing from
// this listing.
//
// Multiplies, in place, the weight of every track state listed in `tips` by
//   sqrt(1 / det(R)) * exp(-0.5 * (chi2 - minChi2)),
// where det(R) comes from calculateDeterminant(...) and minChi2 is the
// smallest chi2 among `tips`. The weights are NOT renormalized here.
// A factor that is not finite is skipped, leaving that weight unchanged.
// `weights` must hold an entry for every tip: weights.at() throws
// std::out_of_range otherwise.
// NOTE(review): `tips` must not be empty -- std::min_element then returns
// tips.end(), which is dereferenced (undefined behavior).
165 template <typename traj_t>
167  const traj_t &mt, const std::vector<MultiTrajectoryTraits::IndexType> &tips,
168  std::map<MultiTrajectoryTraits::IndexType, double> &weights) {
169  // det(R) is obtained per tip from calculateDeterminant() in the loop below
170 
171  // Find the smallest chi2 among the tips. Subtracting it shifts every
172  // exponent by the same amount, so the best tip gets exp(0) = 1:
// relative factors are unchanged, but exp() can no longer underflow for all
// tips at once.
173  const auto minChi2 =
174  mt.getTrackState(*std::min_element(tips.begin(), tips.end(),
175  [&](const auto &a, const auto &b) {
176  return mt.getTrackState(a).chi2() <
177  mt.getTrackState(b).chi2();
178  }))
179  .chi2();
180 
181  // Loop over the tips and compute new weights
182  for (auto tip : tips) {
183  const auto state = mt.getTrackState(tip);
184  const double chi2 = state.chi2() - minChi2;
185  const double detR = calculateDeterminant(
186  // This abuses an incorrectly sized vector / matrix to access the
187  // data pointer! This works (don't use the matrix as is!), but be
188  // careful!
189  state.template calibrated<MultiTrajectoryTraits::MeasurementSizeMax>()
190  .data(),
191  state
192  .template calibratedCovariance<
194  .data(),
195  state.predictedCovariance(), state.projector(), state.calibratedSize());
196 
// Posterior reweighting factor: det(R)^(-1/2) * exp(-chi2 / 2), with chi2
// already shifted by minChi2.
197  const auto factor = std::sqrt(1. / detR) * std::exp(-0.5 * chi2);
198 
199  // If something is not finite here, just leave the weight as it is
200  if (std::isfinite(factor)) {
201  weights.at(tip) *= factor;
202  }
203  }
204 }
205 
/// Identifies one of the three kinds of track-state parameters: predicted,
/// filtered or smoothed.
enum class StatesType { ePredicted, eFiltered, eSmoothed };

/// Stream the lower-case name of a StatesType value ("predicted", "filtered"
/// or "smoothed").
inline std::ostream &operator<<(std::ostream &os, StatesType type) {
  // Indexed by the enumerator's underlying value, so the entries must stay in
  // the same order as the enumerators declared above.
  static constexpr std::array<const char *, 3> kNames{
      {"predicted", "filtered", "smoothed"}};
  return os << kNames[static_cast<std::size_t>(type)];
}
215 
// Projector that maps a track-state index to a tuple
// (weight, parameters, covariance). The parameter set (predicted, filtered or
// smoothed) is selected by the template parameter `type`.
// NOTE(review): source lines 220 (presumably the struct head) and 224
// (presumably the call operator taking `idx`) are missing from this listing.
219 template <StatesType type, typename traj_t>
// Trajectory the track states are read from.
221  const traj_t &mt;
// Weight per track-state index. It is looked up with at(), so an index
// without an entry throws std::out_of_range.
222  const std::map<MultiTrajectoryTraits::IndexType, double> &weights;
223 
225  const auto proxy = mt.getTrackState(idx);
// `type` is a compile-time constant, so exactly one case applies per
// instantiation.
// NOTE(review): there is no return after the switch, so this relies on `type`
// always being one of the three enumerators. Compilers may warn about this
// (e.g. -Wreturn-type); consider `if constexpr` or an explicit unreachable
// marker.
226  switch (type) {
227  case StatesType::ePredicted:
228  return std::make_tuple(weights.at(idx), proxy.predicted(),
229  proxy.predictedCovariance());
230  case StatesType::eFiltered:
231  return std::make_tuple(weights.at(idx), proxy.filtered(),
232  proxy.filteredCovariance());
233  case StatesType::eSmoothed:
234  return std::make_tuple(weights.at(idx), proxy.smoothed(),
235  proxy.smoothedCovariance());
236  }
237  }
238 };
239 
240 } // namespace detail
241 } // namespace Acts