Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
GaussianSumFitter.hpp
Go to the documentation of this file, or view the newest version of GaussianSumFitter.hpp on the sPHENIX GitHub.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
20 
21 #include <fstream>
22 
23 namespace Acts {
24 
25 namespace detail {
26 
// NOTE(review): this listing dropped the lines that declare the trait struct
// itself (listing lines 31 and 34), so only the template headers and the
// `std::true_type` base of the explicit specialization are visible here.
// Presumably a primary template deriving from std::false_type plus a full
// specialization deriving from std::true_type -- confirm against the header.
30 template <typename T>
32 
33 template <>
35  : public std::true_type {};
36 
37 } // namespace detail
38 
50 template <typename propagator_t, typename bethe_heitler_approx_t,
51  typename traj_t>
/// @brief Construct the fitter from a propagator, a Bethe-Heitler
/// approximation and a logger.
/// @param propagator Propagator used for both fit passes; moved into
///        m_propagator.
/// @param bha Bethe-Heitler approximation. Presumably moved into
///        m_betheHeitlerApproximation; that initializer line (listing line 57)
///        is missing from this listing.
/// @param _logger Owning logger of the fitter. Its default argument (listing
///        line 55) is missing from this listing. It is dereferenced below, so
///        it must not be null.
53  GaussianSumFitter(propagator_t&& propagator, bethe_heitler_approx_t&& bha,
54  std::unique_ptr<const Logger> _logger =
56  : m_propagator(std::move(propagator)),
58  m_logger{std::move(_logger)},
// m_actorLogger is derived from m_logger, which works because m_logger is
// declared (and therefore initialized) before m_actorLogger.
59  m_actorLogger(m_logger->cloneWithSuffix("Actor")) {}
60 
63 
/// Bethe-Heitler approximation. Its address is handed to the GsfActor of both
/// propagation passes through `m_cfg.bethe_heitler_approx`.
65  bethe_heitler_approx_t m_betheHeitlerApproximation;
66 
/// Logger of the fitter itself.
68  std::unique_ptr<const Logger> m_logger;
/// Child logger (suffix "Actor") that is handed to the GsfActor. It must stay
/// declared after m_logger because the constructor initializes it from
/// m_logger.
69  std::unique_ptr<const Logger> m_actorLogger;
70 
/// Access the fitter's logger. Presumably used by the ACTS_* logging macros
/// in fit_impl().
71  const Logger& logger() const { return *m_logger; }
72 
/// Navigator type of the configured propagator.
// NOTE(review): not referenced in the visible part of this listing; the fit()
// overloads spell out `typename propagator_t::Navigator` directly.
74  using GsfNavigator = typename propagator_t::Navigator;
75 
78 
/// @brief Fit a track with a propagator that uses the DirectNavigator.
///
/// The forward pass is navigated through the surfaces in `sSequence`. The
/// backward pass visits the same surfaces in reverse order without the last
/// one, followed by `options.referenceSurface`. Both passes hand the
/// Bethe-Heitler approximation to the GsfActor. The fit itself is delegated to
/// fit_impl().
///
/// @param begin Iterator to the first input source link.
/// @param end Past-the-end iterator of the input source links.
/// @param sParameters Start parameters. fit_impl() checks that they carry a
///        covariance and lie on their reference surface.
/// @param sSequence Ordered surfaces for the forward pass.
/// NOTE(review): the declarations of the `options` and `trackContainer`
/// parameters (listing lines 84 and 86) are missing from this listing.
/// @return The result of fit_impl().
80  template <typename source_link_it_t, typename start_parameters_t,
81  typename track_container_t, template <typename> class holder_t>
82  auto fit(source_link_it_t begin, source_link_it_t end,
83  const start_parameters_t& sParameters,
85  const std::vector<const Surface*>& sSequence,
87  const {
88  // Check if we have the correct navigator
89  static_assert(
90  std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
91 
92  // Initialize the forward propagation with the DirectNavigator
93  auto fwdPropInitializer = [&sSequence, this](const auto& opts) {
94  using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
95  using Aborters = AbortList<>;
96 
97  PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
98  opts.magFieldContext);
99 
100  propOptions.setPlainOptions(opts.propagatorPlainOptions);
101 
102  propOptions.actionList.template get<DirectNavigator::Initializer>()
103  .navSurfaces = sSequence;
104  propOptions.actionList.template get<GsfActor>()
105  .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
106 
107  return propOptions;
108  };
109 
110  // Initialize the backward propagation with the DirectNavigator
111  auto bwdPropInitializer = [&sSequence, this](const auto& opts) {
112  using Actors = ActionList<GsfActor, DirectNavigator::Initializer>;
113  using Aborters = AbortList<>;
114 
// Reverse the forward sequence and drop its last surface. This is presumably
// because the backward pass starts from the last forward measurement state
// (see fit_impl).
// NOTE(review): if sSequence is empty, std::next(rbegin()) advances past
// rend(), which is undefined behaviour. Consider rejecting an empty sequence.
// NOTE(review): fit_impl treats options.referenceSurface as optional, so a
// null pointer may be appended here. Confirm the DirectNavigator accepts it.
115  std::vector<const Surface*> backwardSequence(
116  std::next(sSequence.rbegin()), sSequence.rend());
117  backwardSequence.push_back(opts.referenceSurface);
118 
119  PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
120  opts.magFieldContext);
121 
122  propOptions.setPlainOptions(opts.propagatorPlainOptions);
123 
124  propOptions.actionList.template get<DirectNavigator::Initializer>()
125  .navSurfaces = std::move(backwardSequence);
126  propOptions.actionList.template get<GsfActor>()
127  .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
128 
129  return propOptions;
130  };
131 
132  return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
133  bwdPropInitializer, trackContainer);
134  }
135 
/// @brief Fit a track with a propagator that uses the standard Navigator.
///
/// Both passes run only the GsfActor and are aborted by EndOfWorldReached.
/// Surface finding is left to the Navigator. The fit itself is delegated to
/// fit_impl().
///
/// @param begin Iterator to the first input source link.
/// @param end Past-the-end iterator of the input source links.
/// @param sParameters Start parameters. fit_impl() checks that they carry a
///        covariance and lie on their reference surface.
/// NOTE(review): the declarations of the `options` and `trackContainer`
/// parameters (listing lines 141 and 142) are missing from this listing.
/// @return The result of fit_impl().
137  template <typename source_link_it_t, typename start_parameters_t,
138  typename track_container_t, template <typename> class holder_t>
139  auto fit(source_link_it_t begin, source_link_it_t end,
140  const start_parameters_t& sParameters,
143  const {
144  // Check if we have the correct navigator
145  static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);
146 
147  // Initialize the forward propagation with the standard Navigator
148  auto fwdPropInitializer = [this](const auto& opts) {
149  using Actors = ActionList<GsfActor>;
150  using Aborters = AbortList<EndOfWorldReached>;
151 
152  PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
153  opts.magFieldContext);
154  propOptions.setPlainOptions(opts.propagatorPlainOptions);
155  propOptions.actionList.template get<GsfActor>()
156  .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
157 
158  return propOptions;
159  };
160 
161  // Initialize the backward propagation with the standard Navigator
162  auto bwdPropInitializer = [this](const auto& opts) {
163  using Actors = ActionList<GsfActor>;
164  using Aborters = AbortList<EndOfWorldReached>;
165 
166  PropagatorOptions<Actors, Aborters> propOptions(opts.geoContext,
167  opts.magFieldContext);
168 
169  propOptions.setPlainOptions(opts.propagatorPlainOptions);
170 
171  propOptions.actionList.template get<GsfActor>()
172  .m_cfg.bethe_heitler_approx = &m_betheHeitlerApproximation;
173 
174  return propOptions;
175  };
176 
177  return fit_impl(begin, end, sParameters, options, fwdPropInitializer,
178  bwdPropInitializer, trackContainer);
179  }
180 
/// @brief Common implementation of both fit() overloads.
///
/// The fit runs in three steps:
/// 1. A forward GSF pass.
/// 2. A backward pass that starts from the last forward measurement state and
///    propagates to `options.referenceSurface`, falling back to the start
///    parameters' surface.
/// 3. Assembly of the resulting track in `trackContainer`.
///
/// @param begin Iterator to the first input source link.
/// @param end Past-the-end iterator of the input source links. Each source
///        link is keyed by the geometry id of the surface returned by
///        `options.extensions.surfaceAccessor`.
/// @param sParameters Start parameters, either single- or multi-component.
/// @param fwdPropInitializer Callable taking the fit options and returning the
///        PropagatorOptions for the forward pass.
/// @param bwdPropInitializer Same as fwdPropInitializer, for the backward pass.
/// @return The track created via `trackContainer.addTrack()` on success.
///         Otherwise a GsfError or the propagation error. If
///         `options.abortOnError` is set, most failure paths call std::abort()
///         instead of returning.
/// NOTE(review): the return-type line (188) and the `options` / `trackContainer`
/// parameter lines (191, 194) are missing from this listing.
184  template <typename source_link_it_t, typename start_parameters_t,
185  typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
186  typename track_container_t, template <typename> class holder_t>
187  Acts::Result<
189  fit_impl(source_link_it_t begin, source_link_it_t end,
190  const start_parameters_t& sParameters,
192  const fwd_prop_initializer_t& fwdPropInitializer,
193  const bwd_prop_initializer_t& bwdPropInitializer,
195  const {
// Returns `error` unchanged unless options.abortOnError is set, in which case
// the process is aborted.
// NOTE(review): the two start-parameter checks below return their GsfError
// directly instead of through this helper, so abortOnError does not apply to
// them.
196  // return or abort utility
197  auto return_error_or_abort = [&](auto error) {
198  if (options.abortOnError) {
199  std::abort();
200  }
201  return error;
202  };
203 
204  // Define directions based on input propagation direction. This way we can
205  // refer to 'forward' and 'backward' regardless of the actual direction.
206  const auto gsfForward = options.propagatorPlainOptions.direction;
207  const auto gsfBackward = gsfForward.invert();
208 
209  // Check if the start parameters are on the start surface
// NOTE(review): this check uses a default-constructed GeometryContext rather
// than options.geoContext, which the propagation itself uses. Confirm this is
// intended for geometries whose placement depends on the context.
210  auto intersectionStatusStartSurface =
211  sParameters.referenceSurface()
212  .intersect(GeometryContext{},
213  sParameters.position(GeometryContext{}),
214  sParameters.direction(), true)
215  .closest()
216  .status();
217 
218  if (intersectionStatusStartSurface != Intersection3D::Status::onSurface) {
219  ACTS_ERROR(
220  "Surface intersection of start parameters with bound-check failed");
221  return GsfError::StartParametersNotOnStartSurface;
222  }
223 
224  // To be able to find measurements later, we put them into a map
225  // We need to copy input SourceLinks anyway, so the map can own them.
// NOTE(review): std::distance walks the range once before the loop below.
// That is fine for forward iterators but would consume a single-pass input
// iterator.
226  ACTS_VERBOSE("Preparing " << std::distance(begin, end)
227  << " input measurements");
// NOTE(review): std::map::emplace does not overwrite an existing key. If two
// source links map to the same geometry id, only the first is kept, silently.
228  std::map<GeometryIdentifier, SourceLink> inputMeasurements;
229  for (auto it = begin; it != end; ++it) {
230  SourceLink sl = *it;
231  inputMeasurements.emplace(
232  options.extensions.surfaceAccessor(sl)->geometryId(), std::move(sl));
233  }
234 
235  ACTS_VERBOSE(
236  "Gsf: Final measurement map size: " << inputMeasurements.size());
237 
238  if (sParameters.covariance() == std::nullopt) {
239  return GsfError::StartParametersHaveNoCovariance;
240  }
241 
243  // Forward pass
245  ACTS_VERBOSE("+-----------------------------+");
246  ACTS_VERBOSE("| Gsf: Do forward propagation |");
247  ACTS_VERBOSE("+-----------------------------+");
248 
249  auto fwdResult = [&]() {
250  auto fwdPropOptions = fwdPropInitializer(options);
251 
252  // Catch the actor and set the measurements
253  auto& actor = fwdPropOptions.actionList.template get<GsfActor>();
254  actor.setOptions(options);
255  actor.m_cfg.inputMeasurements = &inputMeasurements;
256  actor.m_cfg.numberMeasurements = inputMeasurements.size();
257  actor.m_cfg.inReversePass = false;
258  actor.m_cfg.logger = m_actorLogger.get();
259 
260  fwdPropOptions.direction = gsfForward;
261 
262  // If necessary convert to MultiComponentBoundTrackParameters
// NOTE(review): the right-hand side of this alias (listing line 264) is
// missing from this listing.
263  using IsMultiParameters =
265 
266  typename propagator_t::template action_list_t_result_t<
268  decltype(fwdPropOptions.actionList)>
269  inputResult;
270 
271  auto& r = inputResult.template get<typename GsfActor::result_type>();
272 
// The fitted states are written directly into the caller's track container.
273  r.fittedStates = &trackContainer.trackStateContainer();
274 
275  // This allows the initialization with single- and multicomponent start
276  // parameters
277  if constexpr (not IsMultiParameters::value) {
279  sParameters.referenceSurface().getSharedPtr(),
280  sParameters.parameters(), *sParameters.covariance(),
281  sParameters.particleHypothesis());
282 
283  return m_propagator.propagate(params, fwdPropOptions, false,
284  std::move(inputResult));
285  } else {
286  return m_propagator.propagate(sParameters, fwdPropOptions, false,
287  std::move(inputResult));
288  }
289  }();
290 
291  if (!fwdResult.ok()) {
292  return return_error_or_abort(fwdResult.error());
293  }
294 
295  auto& fwdGsfResult =
296  fwdResult->template get<typename GsfActor::result_type>();
297 
298  if (!fwdGsfResult.result.ok()) {
299  return return_error_or_abort(fwdGsfResult.result.error());
300  }
301 
302  if (fwdGsfResult.measurementStates == 0) {
303  return return_error_or_abort(GsfError::NoMeasurementStatesCreatedForward);
304  }
305 
306  ACTS_VERBOSE("Finished forward propagation");
307  ACTS_VERBOSE("- visited surfaces: " << fwdGsfResult.visitedSurfaces.size());
308  ACTS_VERBOSE("- processed states: " << fwdGsfResult.processedStates);
309  ACTS_VERBOSE("- measurement states: " << fwdGsfResult.measurementStates);
310 
// Accumulated over both passes and reported once as a warning below.
311  std::size_t nInvalidBetheHeitler = fwdGsfResult.nInvalidBetheHeitler;
312 
314  // Backward pass
316  ACTS_VERBOSE("+------------------------------+");
317  ACTS_VERBOSE("| Gsf: Do backward propagation |");
318  ACTS_VERBOSE("+------------------------------+");
319 
320  auto bwdResult = [&]() {
321  auto bwdPropOptions = bwdPropInitializer(options);
322 
323  auto& actor = bwdPropOptions.actionList.template get<GsfActor>();
324  actor.setOptions(options);
325  actor.m_cfg.inputMeasurements = &inputMeasurements;
326  actor.m_cfg.inReversePass = true;
327  actor.m_cfg.logger = m_actorLogger.get();
// NOTE(review): redundant. setOptions(options) was already called at the start
// of this lambda (listing line 324), so one of the two calls can be removed.
328  actor.setOptions(options);
329 
330  bwdPropOptions.direction = gsfBackward;
331 
// The backward pass ends on the requested reference surface if there is one,
// otherwise on the surface of the start parameters.
332  const Surface& target = options.referenceSurface
333  ? *options.referenceSurface
334  : sParameters.referenceSurface();
335 
336  using PM = TrackStatePropMask;
337 
338  typename propagator_t::template action_list_t_result_t<
340  decltype(bwdPropOptions.actionList)>
341  inputResult;
342 
343  // Unfortunately we must construct the result type here to be able to
344  // return an error code
// NOTE(review): ResultType (like PM above) is not referenced anywhere in the
// visible code. It may be left over from an earlier early-return path.
345  using ResultType =
346  decltype(m_propagator.template propagate<
347  MultiComponentBoundTrackParameters, decltype(bwdPropOptions),
349  std::declval<MultiComponentBoundTrackParameters>(),
350  std::declval<Acts::Surface&>(),
351  std::declval<decltype(bwdPropOptions)>(),
352  std::declval<decltype(inputResult)>()));
353 
354  auto& r = inputResult.template get<typename GsfActor::result_type>();
355 
356  r.fittedStates = &trackContainer.trackStateContainer();
357 
// Only checked in debug builds. In release builds, a valid tip presumably
// follows from the measurementStates == 0 check after the forward pass.
358  assert(
359  (fwdGsfResult.lastMeasurementTip != MultiTrajectoryTraits::kInvalid &&
360  "tip is invalid"));
361 
// The backward pass starts from the last forward measurement state. Its
// smoothed component is made to share the filtered one, and the state and its
// surface are registered as already handled by the backward pass.
362  auto proxy =
363  r.fittedStates->getTrackState(fwdGsfResult.lastMeasurementTip);
364  proxy.shareFrom(TrackStatePropMask::Filtered,
365  TrackStatePropMask::Smoothed);
366 
367  r.currentTip = fwdGsfResult.lastMeasurementTip;
368  r.visitedSurfaces.push_back(&proxy.referenceSurface());
369  r.surfacesVisitedBwdAgain.push_back(&proxy.referenceSurface());
370  r.measurementStates++;
371  r.processedStates++;
372 
373  const auto& params = *fwdGsfResult.lastMeasurementState;
374 
375  return m_propagator.template propagate<std::decay_t<decltype(params)>,
376  decltype(bwdPropOptions),
378  params, target, bwdPropOptions, std::move(inputResult));
379  }();
380 
381  if (!bwdResult.ok()) {
382  return return_error_or_abort(bwdResult.error());
383  }
384 
385  auto& bwdGsfResult =
386  bwdResult->template get<typename GsfActor::result_type>();
387 
388  if (!bwdGsfResult.result.ok()) {
389  return return_error_or_abort(bwdGsfResult.result.error());
390  }
391 
392  if (bwdGsfResult.measurementStates == 0) {
393  return return_error_or_abort(
394  GsfError::NoMeasurementStatesCreatedBackward);
395  }
396 
397  nInvalidBetheHeitler += bwdGsfResult.nInvalidBetheHeitler;
398 
399  if (nInvalidBetheHeitler > 0) {
400  ACTS_WARNING("Encountered "
401  << nInvalidBetheHeitler
402  << " cases where the material thickness exceeds the range "
403  "of the Bethe-Heitler-Approximation. Enable DEBUG output "
404  "for more information.");
405  }
406 
408  // Create the final track from the GSF pass results
410  ACTS_VERBOSE("Gsf - States summary:");
411  ACTS_VERBOSE("- Fwd measurement states: " << fwdGsfResult.measurementStates
412  << ", holes: "
413  << fwdGsfResult.measurementHoles);
414  ACTS_VERBOSE("- Bwd measurement states: " << bwdGsfResult.measurementStates
415  << ", holes: "
416  << bwdGsfResult.measurementHoles);
417 
418  // TODO should this be warning level? it happens quite often... Investigate!
419  if (bwdGsfResult.measurementStates != fwdGsfResult.measurementStates) {
420  ACTS_DEBUG("Fwd and bwd measurement states do not match");
421  }
422 
423  // Go through the states and assign outliers / unset smoothed if surface not
424  // passed in backward pass
// Measurement states whose surface was not revisited backwards are demoted to
// outliers. Only the remaining measurement states are counted.
425  const auto& foundBwd = bwdGsfResult.surfacesVisitedBwdAgain;
426  std::size_t measurementStatesFinal = 0;
427 
428  for (auto state : fwdGsfResult.fittedStates->reverseTrackStateRange(
429  fwdGsfResult.currentTip)) {
430  const bool found = std::find(foundBwd.begin(), foundBwd.end(),
431  &state.referenceSurface()) != foundBwd.end();
432  if (not found && state.typeFlags().test(MeasurementFlag)) {
433  state.typeFlags().set(OutlierFlag);
434  state.typeFlags().reset(MeasurementFlag);
435  state.unset(TrackStatePropMask::Smoothed);
436  }
437 
438  measurementStatesFinal +=
439  static_cast<std::size_t>(state.typeFlags().test(MeasurementFlag));
440  }
441 
442  if (measurementStatesFinal == 0) {
443  return return_error_or_abort(GsfError::NoMeasurementStatesCreatedFinal);
444  }
445 
446  auto track = trackContainer.getTrack(trackContainer.addTrack());
447  track.tipIndex() = fwdGsfResult.lastMeasurementTip;
448 
// With a reference surface the backward pass ended there. Its mixture is
// reduced to one parameter vector and covariance using
// options.stateReductionMethod.
449  if (options.referenceSurface) {
450  const auto& params = *bwdResult->endParameters;
451 
452  const auto [finalPars, finalCov] = Acts::reduceGaussianMixture(
453  params.components(), params.referenceSurface(),
454  options.stateReductionMethod, [](auto& t) {
455  return std::tie(std::get<0>(t), std::get<1>(t), *std::get<2>(t));
456  });
457 
458  track.parameters() = finalPars;
459  track.covariance() = finalCov;
460 
461  track.setReferenceSurface(params.referenceSurface().getSharedPtr());
462 
463  if (trackContainer.hasColumn(
// NOTE(review): the statement below has no trailing ';'. It only compiles if
// ACTS_DEBUG expands to a brace-terminated statement, so add the semicolon.
465  ACTS_DEBUG("Add final multi-component state to track")
466  track.template component<GsfConstants::FinalMultiComponentState>(
468  }
469  }
470 
472 
// The hole count is taken from the forward pass only.
473  track.nMeasurements() = measurementStatesFinal;
474  track.nHoles() = fwdGsfResult.measurementHoles;
475 
476  return track;
477  }
478 };
479 
480 } // namespace Acts