Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
PodioTrackStateContainer.hpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file PodioTrackStateContainer.hpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
15 #include "Acts/EventData/Types.hpp"
21 #include "ActsPodioEdm/BoundParametersCollection.h"
22 #include "ActsPodioEdm/JacobianCollection.h"
23 #include "ActsPodioEdm/TrackStateCollection.h"
24 #include "ActsPodioEdm/TrackStateInfo.h"
25 
26 #include <any>
27 #include <memory>
28 #include <stdexcept>
29 #include <tuple>
30 #include <type_traits>
31 
32 #include <podio/CollectionBase.h>
33 #include <podio/Frame.h>
34 
35 #include "podio/UserDataCollection.h"
36 
37 namespace Acts {
38 
39 class MutablePodioTrackStateContainer;
40 class ConstPodioTrackStateContainer;
41 
43  public:
// View types over stored bound-parameter vectors and covariance matrices.
// The derived containers build them from raw `.data()` pointers into the
// podio collections (e.g. `ConstParameters{...values.data()}`), so they
// presumably alias the stored values instead of copying them. The Covariance
// types are also used for jacobians. The Const* variants back the read-only
// accessors. The right-hand sides of these aliases are on lines not shown in
// this listing.
44  using Parameters =
46  using Covariance =
48 
49  using ConstParameters =
51  using ConstCovariance =
53 
54  protected:
// Shared implementation of has(): reports whether track state `istate` of
// `instance` (a derived container whose private m_collection / m_dynamic this
// reads) carries the component named by `key`.
//  - predicted / filtered / smoothed / jacobian: present unless the stored
//    index equals kInvalid (kInvalid is declared on a line not shown here).
//  - calibrated / calibratedCov: present when measdim != 0.
//  - projector: the stored hasProjector flag.
//  - uncalibratedSourceLink: present unless it is PodioUtil::kNoIdentifier.
//  - previous / measdim / referenceSurface / chi2 / pathLength / typeFlags:
//    always reported as present.
//  - any other key: true iff a dynamic column with that hash exists.
55  template <typename T>
56  static constexpr bool has_impl(T& instance, HashedString key,
57  TrackIndexType istate) {
59  using namespace Acts::HashedStringLiteral;
60  auto trackState = instance.m_collection->at(istate);
61  const auto& data = trackState.getData();
62  switch (key) {
63  case "predicted"_hash:
64  return data.ipredicted != kInvalid;
65  case "filtered"_hash:
66  return data.ifiltered != kInvalid;
67  case "smoothed"_hash:
68  return data.ismoothed != kInvalid;
69  case "calibrated"_hash:
70  return data.measdim != 0;
71  case "calibratedCov"_hash:
72  return data.measdim != 0;
73  case "jacobian"_hash:
74  return data.ijacobian != kInvalid;
75  case "projector"_hash:
76  return data.hasProjector;
77  case "uncalibratedSourceLink"_hash:
78  return data.uncalibratedIdentifier != PodioUtil::kNoIdentifier;
79  case "previous"_hash:
80  case "measdim"_hash:
81  case "referenceSurface"_hash:
82  case "chi2"_hash:
83  case "pathLength"_hash:
84  case "typeFlags"_hash:
85  return true;
86  default:
87  return instance.m_dynamic.find(key) != instance.m_dynamic.end();
88  }
89 
// Unreachable: every branch of the switch above returns.
90  return false;
91  }
92 
// Shared implementation of component(): returns a std::any holding a pointer
// to the requested field of track state `istate`. The pointer is to const
// when EnsureConst is true, in which case T must itself be const
// (static_assert). Built-in scalar/index fields point into the state's
// TrackStateInfo record. Any other key is served by the dynamic column of that
// hash, whose get(istate) result is returned.
// Throws std::runtime_error("Unable to handle this component") for unknown keys.
93  template <bool EnsureConst, typename T>
94  static std::any component_impl(T& instance, HashedString key,
95  TrackIndexType istate) {
96  if constexpr (EnsureConst) {
97  static_assert(std::is_const_v<std::remove_reference_t<T>>,
98  "Is not const");
99  }
100  using namespace Acts::HashedStringLiteral;
101  auto trackState = instance.m_collection->at(istate);
// Read-only access goes through getData(); mutable access through data().
102  std::conditional_t<EnsureConst, const ActsPodioEdm::TrackStateInfo*,
103  ActsPodioEdm::TrackStateInfo*>
104  dataPtr;
105  if constexpr (EnsureConst) {
106  dataPtr = &trackState.getData();
107  } else {
108  dataPtr = &trackState.data();
109  }
110  auto& data = *dataPtr;
111  switch (key) {
112  case "previous"_hash:
113  return &data.previous;
114  case "predicted"_hash:
115  return &data.ipredicted;
116  case "filtered"_hash:
117  return &data.ifiltered;
118  case "smoothed"_hash:
119  return &data.ismoothed;
120  case "projector"_hash:
121  return &data.projector;
122  case "measdim"_hash:
123  return &data.measdim;
124  case "chi2"_hash:
125  return &data.chi2;
126  case "pathLength"_hash:
127  return &data.pathLength;
128  case "typeFlags"_hash:
129  return &data.typeFlags;
// NOTE(review): keys such as "calibrated", "jacobian", "referenceSurface" and
// "uncalibratedSourceLink" are not handled above. They fall through to the
// dynamic-column lookup and throw unless a dynamic column with that hash exists.
130  default:
131  auto it = instance.m_dynamic.find(key);
132  if (it == instance.m_dynamic.end()) {
133  throw std::runtime_error("Unable to handle this component");
134  }
// Lines 136-137 (not shown) complete the column pointer's type, presumably
// choosing a const or mutable dynamic-column base via EnsureConst.
135  std::conditional_t<EnsureConst,
138  col = it->second.get();
139  assert(col && "Dynamic column is null");
140  return col->get(istate);
141  }
142  }
143 
144  template <typename T>
145  static constexpr bool hasColumn_impl(T& instance, HashedString key) {
146  using namespace Acts::HashedStringLiteral;
147  switch (key) {
148  case "predicted"_hash:
149  case "filtered"_hash:
150  case "smoothed"_hash:
151  case "jacobian"_hash:
152  case "projector"_hash:
153  case "previous"_hash:
154  case "uncalibratedSourceLink"_hash:
155  case "referenceSurface"_hash:
156  case "measdim"_hash:
157  case "chi2"_hash:
158  case "pathLength"_hash:
159  case "typeFlags"_hash:
160  return true;
161  default:
162  return instance.m_dynamic.find(key) != instance.m_dynamic.end();
163  }
164  }
165 
// Rebuilds the per-state surface buffer: for every track state in
// `collection` (in order) the stored reference surface is converted back via
// PodioUtil::convertSurfaceFromPodio and appended to `surfaces`. The
// declarator line (166) is not shown in this listing. The callers below invoke
// it as populateSurfaceBuffer(m_helper, *m_collection, m_surfaces).
// NOTE(review): declared noexcept, but reserve()/push_back() can throw
// std::bad_alloc, which would call std::terminate. Whether
// convertSurfaceFromPodio can throw is not visible here.
167  const PodioUtil::ConversionHelper& helper,
168  const ActsPodioEdm::TrackStateCollection& collection,
169  std::vector<std::shared_ptr<const Surface>>& surfaces) noexcept {
170  surfaces.reserve(collection.size());
171  for (ActsPodioEdm::TrackState trackState : collection) {
172  surfaces.push_back(PodioUtil::convertSurfaceFromPodio(
173  helper, trackState.getReferenceSurface()));
174  }
175  }
176 };
177 
// Explicit trait specialization deriving from std::true_type. The specialized
// name (line 179) is not shown in this listing. Presumably it flags
// ConstPodioTrackStateContainer as a read-only multi-trajectory; confirm
// against the full header.
178 template <>
180  : std::true_type {};
181 
184  public MultiTrajectory<ConstPodioTrackStateContainer> {
185  public:
// Wraps already-existing collections without copying them: only their
// addresses are stored (m_collection{&trackStates}, ...), so the caller must
// keep all three collections alive for the lifetime of this container.
// The surface buffer is rebuilt from the stored reference surfaces.
// The constructor's declarator line (186) is not shown in this listing.
187  const PodioUtil::ConversionHelper& helper,
188  const ActsPodioEdm::TrackStateCollection& trackStates,
189  const ActsPodioEdm::BoundParametersCollection& params,
190  const ActsPodioEdm::JacobianCollection& jacs)
191  : m_helper{helper},
192  m_collection{&trackStates},
193  m_params{&params},
194  m_jacs{&jacs} {
195  populateSurfaceBuffer(m_helper, *m_collection, m_surfaces);
196  }
197 
198  ConstPodioTrackStateContainer(const PodioUtil::ConversionHelper& helper,
199  const podio::Frame& frame,
200  const std::string& suffix = "")
201  : m_helper{helper},
202  m_collection{nullptr},
203  m_params{nullptr},
204  m_jacs{nullptr} {
205  std::string s = suffix.empty() ? suffix : "_" + suffix;
206 
207  std::vector<std::string> available = frame.getAvailableCollections();
208 
209  std::string trackStatesKey = "trackStates" + s;
210  std::string paramsKey = "trackStateParameters" + s;
211  std::string jacsKey = "trackStateJacobians" + s;
212 
213  if (std::find(available.begin(), available.end(), trackStatesKey) ==
214  available.end()) {
215  throw std::runtime_error{"Track state collection '" + trackStatesKey +
216  "'not found in frame"};
217  }
218 
219  if (std::find(available.begin(), available.end(), paramsKey) ==
220  available.end()) {
221  throw std::runtime_error{"Track state parameters collection '" +
222  paramsKey + "'not found in frame"};
223  }
224 
225  if (std::find(available.begin(), available.end(), jacsKey) ==
226  available.end()) {
227  throw std::runtime_error{"Track state jacobian collection '" + jacsKey +
228  "'not found in frame"};
229  }
230 
231  loadCollection<ActsPodioEdm::TrackStateCollection>(m_collection, frame,
232  trackStatesKey);
233  loadCollection<ActsPodioEdm::BoundParametersCollection>(m_params, frame,
234  paramsKey);
235  loadCollection<ActsPodioEdm::JacobianCollection>(m_jacs, frame, jacsKey);
236 
237  populateSurfaceBuffer(m_helper, *m_collection, m_surfaces);
238 
239  // let's find dynamic columns
240 
241  using load_type = std::unique_ptr<podio_detail::DynamicColumnBase> (*)(
242  const podio::CollectionBase*);
243 
244  using types =
245  std::tuple<int32_t, int64_t, uint32_t, uint64_t, float, double>;
246 
247  for (const auto& col : available) {
248  std::string prefix = trackStatesKey + "_extra__";
249  std::size_t p = col.find(prefix);
250  if (p == std::string::npos) {
251  continue;
252  }
253  std::string dynName = col.substr(prefix.size());
254  const podio::CollectionBase* coll = frame.get(col);
255 
256  std::unique_ptr<podio_detail::ConstDynamicColumnBase> up;
257 
258  std::apply(
259  [&](auto... args) {
260  auto inner = [&](auto arg) {
261  if (up) {
262  return;
263  }
264  using T = decltype(arg);
265  const auto* dyn =
266  dynamic_cast<const podio::UserDataCollection<T>*>(coll);
267  if (dyn == nullptr) {
268  return;
269  }
270  up = std::make_unique<podio_detail::ConstDynamicColumn<T>>(
271  dynName, *dyn);
272  };
273 
274  ((inner(args)), ...);
275  },
276  types{});
277 
278  if (!up) {
279  throw std::runtime_error{"Dynamic column '" + dynName +
280  "' is not of allowed type"};
281  }
282 
283  m_dynamic.insert({hashString(dynName), std::move(up)});
284  }
285  }
286 
287  private:
288  template <typename collection_t>
289  static void loadCollection(collection_t const*& dest,
290  const podio::Frame& frame,
291  const std::string& key) {
292  const auto* collection = frame.get(key);
293 
294  if (const auto* d = dynamic_cast<const collection_t*>(collection);
295  d != nullptr) {
296  dest = d;
297  } else {
298  throw std::runtime_error{"Unable to get collection " + key};
299  }
300  }
301 
302  public:
303  ConstParameters parameters_impl(IndexType istate) const {
304  return ConstParameters{m_params->at(istate).getData().values.data()};
305  }
306 
307  ConstCovariance covariance_impl(IndexType istate) const {
308  return ConstCovariance{m_params->at(istate).getData().covariance.data()};
309  }
310 
311  ConstCovariance jacobian_impl(IndexType istate) const {
312  IndexType ijacobian = m_collection->at(istate).getData().ijacobian;
313  return ConstCovariance{m_jacs->at(ijacobian).getData().values.data()};
314  }
315 
316  template <size_t measdim>
317  ConstTrackStateProxy::Measurement<measdim> measurement_impl(
318  IndexType index) const {
319  return ConstTrackStateProxy::Measurement<measdim>{
320  m_collection->at(index).getData().measurement.data()};
321  }
322 
323  template <size_t measdim>
324  ConstTrackStateProxy::MeasurementCovariance<measdim>
325  measurementCovariance_impl(IndexType index) const {
326  return ConstTrackStateProxy::MeasurementCovariance<measdim>{
327  m_collection->at(index).getData().measurementCovariance.data()};
328  }
329 
330  IndexType size_impl() const { return m_collection->size(); }
331 
332  std::any component_impl(HashedString key, IndexType istate) const {
333  return PodioTrackStateContainerBase::component_impl<true>(*this, key,
334  istate);
335  }
336 
// Reports whether a column named by `key` exists. The return statement
// (line 338) is not shown in this listing. Presumably it forwards to
// PodioTrackStateContainerBase::hasColumn_impl(*this, key), as has_impl below
// does for has().
337  constexpr bool hasColumn_impl(HashedString key) const {
339  }
340 
341  constexpr bool has_impl(HashedString key, IndexType istate) const {
342  return PodioTrackStateContainerBase::has_impl(*this, key, istate);
343  }
344 
345  MultiTrajectoryTraits::IndexType calibratedSize_impl(IndexType istate) const {
346  return m_collection->at(istate).getData().measdim;
347  }
348 
349  SourceLink getUncalibratedSourceLink_impl(IndexType istate) const {
350  return m_helper.get().identifierToSourceLink(
351  m_collection->at(istate).getData().uncalibratedIdentifier);
352  }
353 
354  const Surface* referenceSurface_impl(IndexType istate) const {
355  return m_surfaces.at(istate).get();
356  }
357 
358  private:
// Lets the shared base-class helpers read m_collection / m_dynamic.
359  friend class PodioTrackStateContainerBase;
360 
// Converts stored identifiers/surfaces back into Acts objects.
361  std::reference_wrapper<const PodioUtil::ConversionHelper> m_helper;
// Non-owning: these point either at caller-provided collections or at
// collections obtained from a podio::Frame (see loadCollection).
362  const ActsPodioEdm::TrackStateCollection* m_collection;
363  const ActsPodioEdm::BoundParametersCollection* m_params;
364  const ActsPodioEdm::JacobianCollection* m_jacs;
// One converted reference surface per track state (populateSurfaceBuffer).
365  std::vector<std::shared_ptr<const Surface>> m_surfaces;
366 
// Dynamic columns keyed by hashed name. The member-name line (369) is not
// shown here; the code refers to this member as m_dynamic.
367  std::unordered_map<HashedString,
368  std::unique_ptr<podio_detail::ConstDynamicColumnBase>>
370 };
371 
// Message fixed: this assertion follows the const container and checks that
// it IS read-only (the old text was copied from the mutable checks below).
373  "ConstPodioTrackStateContainer should be read-only");
374 
// Compile-time concept check against ConstMultiTrajectoryBackend. The checked
// type (line 376) is not shown; presumably ConstPodioTrackStateContainer.
375 ACTS_STATIC_CHECK_CONCEPT(ConstMultiTrajectoryBackend,
377 
// Explicit trait specialization deriving from std::false_type. The specialized
// name (line 379) is not shown in this listing. Presumably it flags
// MutablePodioTrackStateContainer as not read-only; confirm against the header.
378 template <>
380  : std::false_type {};
381 
384  public MultiTrajectory<MutablePodioTrackStateContainer> {
385  public:
// Creates fresh, owned, empty track-state / jacobian / parameter collections.
// The constructor's declarator line (386) is not shown in this listing.
387  : m_helper{helper} {
388  m_collection = std::make_unique<ActsPodioEdm::TrackStateCollection>();
389  m_jacs = std::make_unique<ActsPodioEdm::JacobianCollection>();
390  m_params = std::make_unique<ActsPodioEdm::BoundParametersCollection>();
391 
// The collection was just created empty, so this appends nothing; kept for
// symmetry with the const container.
392  populateSurfaceBuffer(m_helper, *m_collection, m_surfaces);
393  }
394 
// Read-only parameters view; declarator (line 395) not shown, presumably
// parameters_impl(IndexType istate) const. `istate` indexes m_params directly
// (a parameters index such as a state's ipredicted), not the track states.
396  return ConstParameters{m_params->at(istate).getData().values.data()};
397  }
398 
// Mutable parameters view (declarator line 399 not shown); uses data(), so
// writes through the view presumably reach the stored values.
400  return Parameters{m_params->at(istate).data().values.data()};
401  }
402 
// Read-only covariance view of parameters entry `istate` (declarator line 403
// not shown).
404  return ConstCovariance{m_params->at(istate).getData().covariance.data()};
405  }
406 
// Mutable covariance view (declarator line 407 not shown).
408  return Covariance{m_params->at(istate).data().covariance.data()};
409  }
410 
// Read-only jacobian view (declarator line 411 not shown). Unlike the accessors
// above, `istate` is a track-state index; the entry is found via ijacobian.
412  IndexType ijacobian = m_collection->at(istate).getData().ijacobian;
413  return ConstCovariance{m_jacs->at(ijacobian).getData().values.data()};
414  }
415 
// Mutable jacobian view (declarator line 416 not shown), same lookup.
417  IndexType ijacobian = m_collection->at(istate).getData().ijacobian;
418  return Covariance{m_jacs->at(ijacobian).data().values.data()};
419  }
420 
// Read-only measurement view over the values stored in track state `index`.
// Parts of the signature and return statement (lines 422, 424) are not shown.
421  template <size_t measdim>
423  IndexType index) const {
425  m_collection->at(index).getData().measurement.data()};
426  }
427 
// Mutable measurement view (signature lines 429-430 not shown); uses data().
428  template <size_t measdim>
431  m_collection->at(index).data().measurement.data()};
432  }
433 
// Read-only measurement-covariance view (lines 436-437 not shown).
434  template <size_t measdim>
435  ConstTrackStateProxy::MeasurementCovariance<measdim>
438  m_collection->at(index).getData().measurementCovariance.data()};
439  }
440 
// Mutable measurement-covariance view (lines 442, 444 not shown).
441  template <size_t measdim>
443  IndexType index) {
445  m_collection->at(index).data().measurementCovariance.data()};
446  }
447 
448  IndexType size_impl() const { return m_collection->size(); }
449 
450  std::any component_impl(HashedString key, IndexType istate) const {
451  return PodioTrackStateContainerBase::component_impl<true>(*this, key,
452  istate);
453  }
454 
455  std::any component_impl(HashedString key, IndexType istate) {
456  return PodioTrackStateContainerBase::component_impl<false>(*this, key,
457  istate);
458  }
459 
// Reports whether a column named by `key` exists. The return statement
// (line 461) is not shown in this listing. Presumably it forwards to
// PodioTrackStateContainerBase::hasColumn_impl(*this, key), as has_impl below
// does for has().
460  constexpr bool hasColumn_impl(HashedString key) const {
462  }
463 
464  constexpr bool has_impl(HashedString key, IndexType istate) const {
465  return PodioTrackStateContainerBase::has_impl(*this, key, istate);
466  }
467 
// Appends a new track state and returns its index (collection size - 1).
// All index components start as kInvalid. One new parameters entry is created
// for each of the Predicted / Filtered / Smoothed bits set in `mask`, and the
// state records its index. A jacobian entry is created under a condition on
// line 492, which is not shown (presumably the Jacobian bit). The reference
// surface is marked PodioUtil::kNoSurface, measdim is 0, and hasProjector is
// set iff the Calibrated bit is set. A null entry is appended to the surface
// buffer, and every dynamic column grows by one element.
468  IndexType addTrackState_impl(
469  TrackStatePropMask mask = TrackStatePropMask::All,
470  TrackIndexType iprevious = kTrackIndexInvalid) {
471  auto trackState = m_collection->create();
472  auto& data = trackState.data();
473  data.previous = iprevious;
474  data.ipredicted = kInvalid;
475  data.ifiltered = kInvalid;
476  data.ismoothed = kInvalid;
477  data.ijacobian = kInvalid;
478  trackState.referenceSurface().surfaceType = PodioUtil::kNoSurface;
479 
480  if (ACTS_CHECK_BIT(mask, TrackStatePropMask::Predicted)) {
481  m_params->create();
482  data.ipredicted = m_params->size() - 1;
483  }
484  if (ACTS_CHECK_BIT(mask, TrackStatePropMask::Filtered)) {
485  m_params->create();
486  data.ifiltered = m_params->size() - 1;
487  }
488  if (ACTS_CHECK_BIT(mask, TrackStatePropMask::Smoothed)) {
489  m_params->create();
490  data.ismoothed = m_params->size() - 1;
491  }
493  m_jacs->create();
494  data.ijacobian = m_jacs->size() - 1;
495  }
// Only the dimension/projector flag is recorded here; the measurement values
// themselves live in the track-state record (see measurement_impl).
496  data.measdim = 0;
497  data.hasProjector = false;
498  if (ACTS_CHECK_BIT(mask, TrackStatePropMask::Calibrated)) {
499  data.hasProjector = true;
500  }
// Placeholder (null) surface, keeping m_surfaces parallel to m_collection.
501  m_surfaces.emplace_back();
502 
503  data.uncalibratedIdentifier = PodioUtil::kNoIdentifier;
504  assert(m_collection->size() == m_surfaces.size() &&
505  "Inconsistent surface buffer");
506 
// Keep every dynamic column the same length as the track-state collection.
507  for (const auto& [key, vec] : m_dynamic) {
508  vec->add();
509  }
510 
511  return m_collection->size() - 1;
512  }
513 
514  void shareFrom_impl(TrackIndexType iself, TrackIndexType iother,
515  TrackStatePropMask shareSource,
516  TrackStatePropMask shareTarget) {
517  auto& self = m_collection->at(iself).data();
518  auto& other = m_collection->at(iother).data();
519 
520  assert(ACTS_CHECK_BIT(getTrackState(iother).getMask(), shareSource) &&
521  "Source has incompatible allocation");
522 
523  using PM = TrackStatePropMask;
524 
525  IndexType sourceIndex{kInvalid};
526  switch (shareSource) {
527  case PM::Predicted:
528  sourceIndex = other.ipredicted;
529  break;
530  case PM::Filtered:
531  sourceIndex = other.ifiltered;
532  break;
533  case PM::Smoothed:
534  sourceIndex = other.ismoothed;
535  break;
536  case PM::Jacobian:
537  sourceIndex = other.ijacobian;
538  break;
539  default:
540  throw std::domain_error{"Unable to share this component"};
541  }
542 
543  assert(sourceIndex != kInvalid);
544 
545  switch (shareTarget) {
546  case PM::Predicted:
547  assert(shareSource != PM::Jacobian);
548  self.ipredicted = sourceIndex;
549  break;
550  case PM::Filtered:
551  assert(shareSource != PM::Jacobian);
552  self.ifiltered = sourceIndex;
553  break;
554  case PM::Smoothed:
555  assert(shareSource != PM::Jacobian);
556  self.ismoothed = sourceIndex;
557  break;
558  case PM::Jacobian:
559  assert(shareSource == PM::Jacobian);
560  self.ijacobian = sourceIndex;
561  break;
562  default:
563  throw std::domain_error{"Unable to share this component"};
564  }
565  }
566 
// Marks one component of track state `istate` as absent. The signature
// (line 567) and the Jacobian case label (line 579) are not shown in this
// listing. Index components are reset to kInvalid; the parameter/jacobian
// entries they pointed at are not erased. Calibrated resets measdim to 0,
// which makes has("calibrated") false; hasProjector is left unchanged.
// Throws std::domain_error for any other mask.
568  auto& data = m_collection->at(istate).data();
569  switch (target) {
570  case TrackStatePropMask::Predicted:
571  data.ipredicted = kInvalid;
572  break;
573  case TrackStatePropMask::Filtered:
574  data.ifiltered = kInvalid;
575  break;
576  case TrackStatePropMask::Smoothed:
577  data.ismoothed = kInvalid;
578  break;
580  data.ijacobian = kInvalid;
581  break;
582  case TrackStatePropMask::Calibrated:
583  data.measdim = 0;
584  break;
585  default:
586  throw std::domain_error{"Unable to unset this component"};
587  }
588  }
589 
590  void clear_impl() {
591  m_collection->clear();
592  m_params->clear();
593  m_surfaces.clear();
594  for (const auto& [key, vec] : m_dynamic) {
595  vec->clear();
596  }
597  }
598 
599  template <typename T>
600  constexpr void addColumn_impl(const std::string& key) {
601  m_dynamic.insert({hashString(key),
602  std::make_unique<podio_detail::DynamicColumn<T>>(key)});
603  }
604 
605  void allocateCalibrated_impl(IndexType istate, size_t measdim) {
606  assert(measdim > 0 && "Zero measdim not supported");
607  auto& data = m_collection->at(istate).data();
608  data.measdim = measdim;
609  }
610 
// Converts `sourceLink` to a podio identifier via the conversion helper and
// stores it in track state `istate`. The declaration of `id` (line 613) is not
// shown in this listing.
611  void setUncalibratedSourceLink_impl(IndexType istate,
612  const SourceLink& sourceLink) {
614  m_helper.get().sourceLinkToIdentifier(sourceLink);
615  m_collection->at(istate).data().uncalibratedIdentifier = id;
616  }
617 
618  void setReferenceSurface_impl(IndexType istate,
619  std::shared_ptr<const Surface> surface) {
620  auto trackState = m_collection->at(istate);
621  trackState.setReferenceSurface(
622  PodioUtil::convertSurfaceToPodio(m_helper, *surface));
623  m_surfaces.at(istate) = std::move(surface);
624  }
625 
// Stored measurement dimension of state `istate` (0 if not calibrated).
// Declarator (line 626) not shown, presumably calibratedSize_impl.
627  return m_collection->at(istate).getData().measdim;
628  }
629 
// Source link rebuilt from the identifier stored in state `istate`.
// Declarator (line 630) not shown, presumably getUncalibratedSourceLink_impl.
631  return m_helper.get().identifierToSourceLink(
632  m_collection->at(istate).getData().uncalibratedIdentifier);
633  }
634 
635  const Surface* referenceSurface_impl(IndexType istate) const {
636  return m_surfaces.at(istate).get();
637  }
638 
639  void releaseInto(podio::Frame& frame, const std::string& suffix = "") {
640  std::string s = suffix;
641  if (!s.empty()) {
642  s = "_" + s;
643  }
644  frame.put(std::move(m_collection), "trackStates" + s);
645  frame.put(std::move(m_params), "trackStateParameters" + s);
646  frame.put(std::move(m_jacs), "trackStateJacobians" + s);
647  m_surfaces.clear();
648 
649  for (const auto& [key, col] : m_dynamic) {
650  col->releaseInto(frame, "trackStates" + s + "_extra__");
651  }
652  }
653 
654  private:
// Lines 655-656 are not shown in this listing (presumably the friend
// declaration and the kInvalid constant used above).
657 
// Converts surfaces and source links to/from their podio representation.
658  std::reference_wrapper<PodioUtil::ConversionHelper> m_helper;
// Owned collections. releaseInto() moves them into a podio::Frame, after which
// these pointers are empty.
659  std::unique_ptr<ActsPodioEdm::TrackStateCollection> m_collection;
660  std::unique_ptr<ActsPodioEdm::BoundParametersCollection> m_params;
661  std::unique_ptr<ActsPodioEdm::JacobianCollection> m_jacs;
// One surface per track state, kept parallel to m_collection (asserted in
// addTrackState_impl).
662  std::vector<std::shared_ptr<const Surface>> m_surfaces;
663 
// Dynamic columns keyed by hashed name. The member-name line (666) is not
// shown here; the code refers to this member as m_dynamic.
664  std::unordered_map<HashedString,
665  std::unique_ptr<podio_detail::DynamicColumnBase>>
667 };
668 
// Compile-time checks that the mutable container is not read-only. The
// asserted conditions (lines 670 and 673) are not shown in this listing.
669 static_assert(
671  "MutablePodioTrackStateContainer should not be read-only");
672 
674  "MutablePodioTrackStateContainer should not be read-only");
675 
// Compile-time concept check against MutableMultiTrajectoryBackend. The
// checked type (line 677) is not shown; presumably
// MutablePodioTrackStateContainer.
676 ACTS_STATIC_CHECK_CONCEPT(MutableMultiTrajectoryBackend,
678 
679 // ConstPodioTrackStateContainer::ConstPodioTrackStateContainer(
680 // MutablePodioTrackStateContainer&& other)
681 // : m_helper{other.m_helper},
682 // m_collection{std::move(other.m_collection)},
683 // m_params{std::move(other.m_params)},
684 // m_jacs{std::move(other.m_jacs)},
685 // m_surfaces{std::move(other.m_surfaces)} {}
686 
687 // ConstPodioTrackStateContainer::ConstPodioTrackStateContainer(
688 // const MutablePodioTrackStateContainer& other)
689 // : m_helper{other.m_helper},
690 // m_surfaces{other.m_surfaces.begin(), other.m_surfaces.end()} {
691 // for (auto src : *other.m_collection) {
692 // auto dst = m_collection->create();
693 // dst = src.clone();
694 // }
695 // for (auto src : *other.m_params) {
696 // auto dst = m_params->create();
697 // dst = src.clone();
698 // }
699 // for (auto src : *other.m_jacs) {
700 // auto dst = m_jacs->create();
701 // dst = src.clone();
702 // }
703 // }
704 
705 } // namespace Acts