Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MLTrackClassifier.cpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file MLTrackClassifier.cpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2020 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
10 
11 #include <cassert>
12 #include <stdexcept>
13 
14 // prediction function
16  std::vector<float>& inputFeatures, double decisionThreshProb) const {
17  // check that the decision threshold is a probability
18  if (!((0. <= decisionThreshProb) && (decisionThreshProb <= 1.))) {
19  throw std::invalid_argument(
20  "predictTrackLabel: Decision threshold "
21  "probability is not in [0, 1].");
22  }
23 
24  // run the model over the input
25  std::vector<float> outputTensor = runONNXInference(inputFeatures);
26  // this is binary classification, so only need first value
27  float outputProbability = outputTensor[0];
28 
29  // the output layer computes how confident the network is that the track is a
30  // duplicate, so need to convert that to a label
31  if (outputProbability > decisionThreshProb) {
32  return TrackLabels::eDuplicate;
33  }
34  return TrackLabels::eGood;
35 }
36 
37 // function that checks if the predicted track label is duplicate
38 bool Acts::MLTrackClassifier::isDuplicate(std::vector<float>& inputFeatures,
39  double decisionThreshProb) const {
42  decisionThreshProb);
43  return predictedLabel == Acts::MLTrackClassifier::TrackLabels::eDuplicate;
44 }