Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
AmbiguityTrackClassifier.hpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file AmbiguityTrackClassifier.hpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
15 
16 #include <map>
17 #include <unordered_map>
18 #include <vector>
19 
20 #include <onnxruntime_cxx_api.h>
21 
22 namespace Acts {
23 
26  public:
/// Create the ONNX Runtime environment (warning log level, log id
/// "MLClassifier") and the duplicate-classifier model.
///
/// @param modelPath path of the model, passed to m_duplicateClassifier
///        together with m_env.
/// NOTE(review): m_duplicateClassifier is constructed from m_env, so m_env
/// must remain declared before it (members are initialized in declaration
/// order, not in initializer-list order).
/// NOTE(review): this single-argument constructor is not `explicit`, so it
/// permits implicit conversion from const char*.
30  AmbiguityTrackClassifier(const char* modelPath)
31  : m_env(ORT_LOGGING_LEVEL_WARNING, "MLClassifier"),
32  m_duplicateClassifier(m_env, modelPath) {}
33 
/// Build the network input for every track referenced by `clusters` and
/// return the scores the network computes for them.
///
/// Each input row holds 8 features, in this order: number of states, number
/// of measurements, number of outliers, number of holes, NDF, chi2/NDF (NDF
/// replaced by 1 when it is 0), and the eta and phi of the track momentum.
///
/// @param clusters map whose values are lists of track IDs; the keys are not
///        used. Rows are filled in the iteration order of this map (and of
///        each ID list), so results can only be matched back to tracks by
///        iterating the same, unmodified map again (as trackSelection does).
/// @param tracks container in which the track IDs are looked up via
///        getTrack().
/// @return the network output; presumably one row per input track in the
///         row order described above (TODO confirm against the inference
///         call).
/// NOTE(review): `clusters` is not modified in the visible code and could
/// probably be taken by const reference.
39  template <typename track_container_t, typename traj_t,
40  template <typename> class holder_t>
41  std::vector<std::vector<float>> inferScores(
42  std::unordered_map<int, std::vector<int>>& clusters,
44  const {
45  // Count the network inputs: one per track listed in the clusters (this
46  // can be smaller than the number of tracks in the container).
47  int trackNb = 0;
48  for (const auto& [_, val] : clusters) {
49  trackNb += val.size();
50  }
51  // Input of the neural network: one row per track, 8 features per row.
52  Acts::NetworkBatchInput networkInput(trackNb, 8);
53  int inputID = 0;
54  // Fill the features of every track; inputID is the current row index.
55  for (const auto& [key, val] : clusters) {
56  for (const auto& trackID : val) {
57  auto track = tracks.getTrack(trackID);
59  tracks.trackStateContainer(), track.tipIndex());
60  networkInput(inputID, 0) = trajState.nStates;
61  networkInput(inputID, 1) = trajState.nMeasurements;
62  networkInput(inputID, 2) = trajState.nOutliers;
63  networkInput(inputID, 3) = trajState.nHoles;
64  networkInput(inputID, 4) = trajState.NDF;
// chi2 per degree of freedom; divide by 1 when NDF is 0 to avoid a
// division by zero.
65  networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
66  (trajState.NDF != 0 ? trajState.NDF : 1);
67  networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
68  networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
69  inputID++;
70  }
71  }
72  // Use the network to compute a score for all the tracks.
73  std::vector<std::vector<float>> outputTensor =
75  return outputTensor;
76  }
77 
83  std::vector<int> trackSelection(
84  std::unordered_map<int, std::vector<int>>& clusters,
85  std::vector<std::vector<float>>& outputTensor) const {
86  std::vector<int> goodTracks;
87  int iOut = 0;
88  // Loop over all the cluster and only keep the track with the highest score
89  // in each cluster
90  for (const auto& [key, val] : clusters) {
91  int bestTrackID = 0;
92  float bestTrackScore = 0;
93  for (const auto& track : val) {
94  if (outputTensor[iOut][0] > bestTrackScore) {
95  bestTrackScore = outputTensor[iOut][0];
96  bestTrackID = track;
97  }
98  iOut++;
99  }
100  goodTracks.push_back(bestTrackID);
101  }
102  return goodTracks;
103  }
104 
/// Resolve ambiguities: score every track in `clusters` with inferScores,
/// then keep one track per cluster with trackSelection.
///
/// @param clusters map whose values are lists of track IDs (keys unused);
///        the same map is passed to both steps so that scores stay aligned
///        with tracks by position.
/// @param tracks container the track IDs refer to.
/// @return the track IDs selected by trackSelection.
/// NOTE(review): the name is misspelled ("Ambuguity"). Renaming it would
/// break callers, so consider adding a correctly spelled forwarding alias.
110  template <typename track_container_t, typename traj_t,
111  template <typename> class holder_t>
112  std::vector<int> solveAmbuguity(
113  std::unordered_map<int, std::vector<int>>& clusters,
115  const {
116  std::vector<std::vector<float>> outputTensor =
117  inferScores(clusters, tracks);
118  std::vector<int> goodTracks = trackSelection(clusters, outputTensor);
119 
120  return goodTracks;
121  }
122 
123  private:
124  // ONNX environment
125  Ort::Env m_env;
126  // ONNX model for the duplicate neural network
128 };
129 
130 } // namespace Acts