Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TrackFindingAlgorithmExaTrkX.cpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file TrackFindingAlgorithmExaTrkX.cpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2022 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
10 
13 #include "Acts/Utilities/Zip.hpp"
19 
20 #include <numeric>
21 
22 using namespace ActsExamples;
23 using namespace Acts::UnitLiterals;
24 
25 namespace {
26 
27 class ExamplesEdmHook : public Acts::ExaTrkXHook {
28  double m_targetPT = 0.5_GeV;
29  std::size_t m_targetSize = 3;
30 
31  std::unique_ptr<const Acts::Logger> m_logger;
32  std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
33  std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
34 
35  const Acts::Logger& logger() const { return *m_logger; }
36 
37  struct HitInfo {
38  std::size_t spacePointIndex;
39  int32_t hitIndex;
40  };
41 
42  public:
43  ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
44  const IndexMultimap<Index>& measHitMap,
45  const SimHitContainer& simhits,
47  std::size_t targetMinHits, double targetMinPT,
48  const Acts::Logger& logger)
49  : m_targetPT(targetMinPT),
50  m_targetSize(targetMinHits),
51  m_logger(logger.clone("MetricsHook")) {
52  // Associate tracks to graph, collect momentum
53  std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
54 
55  for (auto i = 0ul; i < spacepoints.size(); ++i) {
56  const auto measId = spacepoints[i]
57  .sourceLinks()[0]
58  .template get<IndexSourceLink>()
59  .index();
60 
61  auto [a, b] = measHitMap.equal_range(measId);
62  for (auto it = a; it != b; ++it) {
63  const auto& hit = *simhits.nth(it->second);
64 
65  tracks[hit.particleId()].push_back({i, hit.index()});
66  }
67  }
68 
69  // Collect edges for truth graph and target graph
70  std::vector<int64_t> truthGraph;
71  std::vector<int64_t> targetGraph;
72 
73  for (auto& [pid, track] : tracks) {
74  // Sort by hit index, so the edges are connected correctly
75  std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
76  return a.hitIndex < b.hitIndex;
77  });
78 
79  auto found = particles.find(pid);
80  if (found == particles.end()) {
81  ACTS_WARNING("Did not find " << pid << ", skip track");
82  continue;
83  }
84 
85  for (auto i = 0ul; i < track.size() - 1; ++i) {
86  truthGraph.push_back(track[i].spacePointIndex);
87  truthGraph.push_back(track[i + 1].spacePointIndex);
88 
89  if (found->transverseMomentum() > m_targetPT &&
90  track.size() >= m_targetSize) {
91  targetGraph.push_back(track[i].spacePointIndex);
92  targetGraph.push_back(track[i + 1].spacePointIndex);
93  }
94  }
95  }
96 
97  m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
98  truthGraph, logger.clone());
99  m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
100  targetGraph, logger.clone());
101  }
102 
103  ~ExamplesEdmHook() {}
104 
105  void operator()(const std::any& nodes, const std::any& edges) const override {
106  ACTS_INFO("Metrics for total graph:");
107  (*m_truthGraphHook)(nodes, edges);
108  ACTS_INFO("Metrics for target graph (pT > "
109  << m_targetPT / Acts::UnitConstants::GeV
110  << " GeV, nHits >= " << m_targetSize << "):");
111  (*m_targetGraphHook)(nodes, edges);
112  }
113 };
114 
115 } // namespace
116 
// Constructor: stores the config, builds the Exa.TrkX pipeline from the
// configured graph constructor, edge classifiers and track builder, and
// rejects configurations without spacepoint input or proto-track output.
// Throws std::invalid_argument for a missing collection name.
119  : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
120  m_cfg(std::move(config)),
121  m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
122  m_cfg.trackBuilder, logger().clone()) {
123  if (m_cfg.inputSpacePoints.empty()) {
124  throw std::invalid_argument("Missing spacepoint input collection");
125  }
126  if (m_cfg.outputProtoTracks.empty()) {
127  throw std::invalid_argument("Missing protoTrack output collection");
128  }
129 
130  // Sanitizer run with dummy input to detect configuration issues
131  // TODO This would be quite helpful I think, but currently it does not work
132  // in general because the stages do not expose the number of node features.
133  // However, this must be addressed anyway when we also want to allow to
134  // configure this more flexibly with e.g. cluster information as input. So
135  // for now, we disable this.
136 #if 0
137  if( m_cfg.sanitize ) {
138  Eigen::VectorXf dummyInput = Eigen::VectorXf::Random(3 * 15);
139  std::vector<float> dummyInputVec(dummyInput.data(),
140  dummyInput.data() + dummyInput.size());
// NOTE(review): spacepointIDs is default-constructed (empty), so the
// std::iota below writes nothing. If this block is re-enabled, the vector
// should presumably be sized to the number of dummy spacepoints (15).
141  std::vector<int> spacepointIDs;
142  std::iota(spacepointIDs.begin(), spacepointIDs.end(), 0);
143 
144  runPipeline(dummyInputVec, spacepointIDs);
145  }
146 #endif
147 
151 
155 
156  // reserve one timing accumulator per edge classifier; execute() asserts
156  // that the pipeline reports exactly this many classifier timings
157  m_timing.classifierTimes.resize(
158  m_cfg.edgeClassifiers.size(),
159  decltype(m_timing.classifierTimes)::value_type{0.f});
160 }
161 
// Column indices into the per-spacepoint feature row built in execute().
// The row has 3 entries without cluster input and 7 entries with it, so the
// geometric features (eR, ePhi, eZ) must occupy columns 0-2.
// NOTE(review): execute() also indexes with ePhi, eCellCount, eCellSum,
// eClusterX and eClusterY; their declarations are not visible here. Confirm
// that their order matches the feature layout the trained model expects.
163 enum feat : std::size_t {
164  eR = 0,
166  eZ,
171 };
172 
174  const ActsExamples::AlgorithmContext& ctx) const {
175  // Read input data
176  auto spacepoints = m_inputSpacePoints(ctx);
177 
178  auto hook = std::make_unique<Acts::ExaTrkXHook>();
179  if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
180  hook = std::make_unique<ExamplesEdmHook>(
181  spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
182  m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
183  logger());
184  }
185 
186  std::optional<ClusterContainer> clusters;
187  if (m_inputClusters.isInitialized()) {
188  clusters = m_inputClusters(ctx);
189  }
190 
191  // Convert Input data to a list of size [num_measurements x
192  // measurement_features]
193  const std::size_t numSpacepoints = spacepoints.size();
194  const std::size_t numFeatures = clusters ? 7 : 3;
195  ACTS_INFO("Received " << numSpacepoints << " spacepoints");
196 
197  std::vector<float> features(numSpacepoints * numFeatures);
198  std::vector<int> spacepointIDs;
199 
200  spacepointIDs.reserve(spacepoints.size());
201 
202  double sumCells = 0.0;
203  double sumActivation = 0.0;
204 
205  for (auto i = 0ul; i < numSpacepoints; ++i) {
206  const auto& sp = spacepoints[i];
207 
208  // I would prefer to use a std::span or boost::span here once available
209  float* featurePtr = features.data() + i * numFeatures;
210 
211  // For now just take the first index since does require one single index
212  // per spacepoint
213  const auto& sl = sp.sourceLinks()[0].template get<IndexSourceLink>();
214  spacepointIDs.push_back(sl.index());
215 
216  featurePtr[eR] = std::hypot(sp.x(), sp.y()) / m_cfg.rScale;
217  featurePtr[ePhi] = std::atan2(sp.y(), sp.x()) / m_cfg.phiScale;
218  featurePtr[eZ] = sp.z() / m_cfg.zScale;
219 
220  if (clusters) {
221  const auto& cluster = clusters->at(sl.index());
222  const auto& chnls = cluster.channels;
223 
224  featurePtr[eCellCount] = cluster.channels.size() / m_cfg.cellCountScale;
225  featurePtr[eCellSum] =
226  std::accumulate(chnls.begin(), chnls.end(), 0.0,
227  [](double s, const Cluster::Cell& c) {
228  return s + c.activation;
229  }) /
230  m_cfg.cellSumScale;
231  featurePtr[eClusterX] = cluster.sizeLoc0 / m_cfg.clusterXScale;
232  featurePtr[eClusterY] = cluster.sizeLoc1 / m_cfg.clusterYScale;
233 
234  sumCells += featurePtr[eCellCount];
235  sumActivation += featurePtr[eCellSum];
236  }
237  }
238 
239  ACTS_DEBUG("Avg cell count: " << sumCells / spacepoints.size());
240  ACTS_DEBUG("Avg activation: " << sumActivation / sumCells);
241 
242  // Run the pipeline
243  const auto trackCandidates = [&]() {
244  const int deviceHint = -1;
245  std::lock_guard<std::mutex> lock(m_mutex);
246 
247  Acts::ExaTrkXTiming timing;
248  auto res =
249  m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
250 
251  m_timing.graphBuildingTime(timing.graphBuildingTime.count());
252 
253  assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
254  for (auto [aggr, a] :
255  Acts::zip(m_timing.classifierTimes, timing.classifierTimes)) {
256  aggr(a.count());
257  }
258 
259  m_timing.trackBuildingTime(timing.trackBuildingTime.count());
260 
261  return res;
262  }();
263 
264  ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
265  << " candidates");
266 
267  // Make the prototracks
268  std::vector<ProtoTrack> protoTracks;
269  protoTracks.reserve(trackCandidates.size());
270  for (auto& x : trackCandidates) {
271  ProtoTrack onetrack;
272  std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
273  protoTracks.push_back(std::move(onetrack));
274  }
275 
276  ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
277  m_outputProtoTracks(ctx, std::move(protoTracks));
278 
280 }
281 
283  namespace ba = boost::accumulators;
284 
285  ACTS_INFO("Exa.TrkX timing info");
286  {
287  const auto& t = m_timing.graphBuildingTime;
288  ACTS_INFO("- graph building: " << ba::mean(t) << " +- "
289  << std::sqrt(ba::variance(t)));
290  }
291  for (const auto& t : m_timing.classifierTimes) {
292  ACTS_INFO("- classifier: " << ba::mean(t) << " +- "
293  << std::sqrt(ba::variance(t)));
294  }
295  {
296  const auto& t = m_timing.trackBuildingTime;
297  ACTS_INFO("- track building: " << ba::mean(t) << " +- "
298  << std::sqrt(ba::variance(t)));
299  }
300 
301  return {};
302 }