Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExaTrkXTrackFinding.cpp
Go to the documentation of this file, or view the newest version of ExaTrkXTrackFinding.cpp on the sPHENIX GitHub.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
21 
22 #include <memory>
23 
24 #include <pybind11/functional.h>
25 #include <pybind11/pybind11.h>
26 #include <pybind11/stl.h>
27 
28 namespace py = pybind11;
29 
30 using namespace ActsExamples;
31 using namespace Acts;
32 
33 namespace Acts::Python {
34 
36  auto [m, mex] = ctx.get("main", "examples");
37 
38  {
40  auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
41  }
42  {
44  auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
45  }
46  {
47  using C = Acts::TrackBuildingBase;
48  auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
49  }
50 
51 #ifdef ACTS_EXATRKX_TORCH_BACKEND
52  {
53  using Alg = Acts::TorchMetricLearning;
54  using Config = Alg::Config;
55 
56  auto alg =
57  py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
58  mex, "TorchMetricLearning")
59  .def(py::init([](const Config &c, Logging::Level lvl) {
60  return std::make_shared<Alg>(
61  c, getDefaultLogger("MetricLearning", lvl));
62  }),
63  py::arg("config"), py::arg("level"))
64  .def_property_readonly("config", &Alg::config);
65 
66  auto c = py::class_<Config>(alg, "Config").def(py::init<>());
68  ACTS_PYTHON_MEMBER(modelPath);
69  ACTS_PYTHON_MEMBER(numFeatures);
70  ACTS_PYTHON_MEMBER(embeddingDim);
71  ACTS_PYTHON_MEMBER(rVal);
72  ACTS_PYTHON_MEMBER(knnVal);
74  }
75  {
76  using Alg = Acts::TorchEdgeClassifier;
77  using Config = Alg::Config;
78 
79  auto alg =
80  py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
81  mex, "TorchEdgeClassifier")
82  .def(py::init([](const Config &c, Logging::Level lvl) {
83  return std::make_shared<Alg>(
84  c, getDefaultLogger("EdgeClassifier", lvl));
85  }),
86  py::arg("config"), py::arg("level"))
87  .def_property_readonly("config", &Alg::config);
88 
89  auto c = py::class_<Config>(alg, "Config").def(py::init<>());
91  ACTS_PYTHON_MEMBER(modelPath);
92  ACTS_PYTHON_MEMBER(numFeatures);
93  ACTS_PYTHON_MEMBER(cut);
94  ACTS_PYTHON_MEMBER(nChunks);
95  ACTS_PYTHON_MEMBER(undirected);
97  }
98  {
99  using Alg = Acts::BoostTrackBuilding;
100 
101  auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
102  mex, "BoostTrackBuilding")
103  .def(py::init([](Logging::Level lvl) {
104  return std::make_shared<Alg>(
105  getDefaultLogger("EdgeClassifier", lvl));
106  }),
107  py::arg("level"));
108  }
109 #endif
110 
111 #ifdef ACTS_EXATRKX_ONNX_BACKEND
112  {
113  using Alg = Acts::OnnxMetricLearning;
114  using Config = Alg::Config;
115 
116  auto alg =
117  py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
118  mex, "OnnxMetricLearning")
119  .def(py::init([](const Config &c, Logging::Level lvl) {
120  return std::make_shared<Alg>(
121  c, getDefaultLogger("MetricLearning", lvl));
122  }),
123  py::arg("config"), py::arg("level"))
124  .def_property_readonly("config", &Alg::config);
125 
126  auto c = py::class_<Config>(alg, "Config").def(py::init<>());
128  ACTS_PYTHON_MEMBER(modelPath);
129  ACTS_PYTHON_MEMBER(spacepointFeatures);
130  ACTS_PYTHON_MEMBER(embeddingDim);
131  ACTS_PYTHON_MEMBER(rVal);
132  ACTS_PYTHON_MEMBER(knnVal);
134  }
135  {
136  using Alg = Acts::OnnxEdgeClassifier;
137  using Config = Alg::Config;
138 
139  auto alg =
140  py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
141  mex, "OnnxEdgeClassifier")
142  .def(py::init([](const Config &c, Logging::Level lvl) {
143  return std::make_shared<Alg>(
144  c, getDefaultLogger("EdgeClassifier", lvl));
145  }),
146  py::arg("config"), py::arg("level"))
147  .def_property_readonly("config", &Alg::config);
148 
149  auto c = py::class_<Config>(alg, "Config").def(py::init<>());
151  ACTS_PYTHON_MEMBER(modelPath);
152  ACTS_PYTHON_MEMBER(cut);
154  }
155  {
156  using Alg = Acts::CugraphTrackBuilding;
157 
158  auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
159  mex, "CugraphTrackBuilding")
160  .def(py::init([](Logging::Level lvl) {
161  return std::make_shared<Alg>(
162  getDefaultLogger("EdgeClassifier", lvl));
163  }),
164  py::arg("level"));
165  }
166 #endif
167 
170  "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
171  inputParticles, inputClusters, inputMeasurementSimhitsMap,
172  outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder,
173  rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale,
174  clusterYScale, targetMinHits, targetMinPT);
175 
176  {
177  auto cls =
178  py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
179  mex, "ExaTrkXHook");
180  }
181 
182  {
183  using Class = Acts::TorchTruthGraphMetricsHook;
184 
185  auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
186  mex, "TorchTruthGraphMetricsHook")
187  .def(py::init(
188  [](const std::vector<int64_t> &g, Logging::Level lvl) {
189  return std::make_shared<Class>(
190  g, getDefaultLogger("PipelineHook", lvl));
191  }));
192  }
193 
194  {
195  using Class = Acts::ExaTrkXPipeline;
196 
197  auto cls =
198  py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
199  .def(py::init(
200  [](std::shared_ptr<GraphConstructionBase> g,
201  std::vector<std::shared_ptr<EdgeClassificationBase>> e,
202  std::shared_ptr<TrackBuildingBase> t,
203  Logging::Level lvl) {
204  return std::make_shared<Class>(
205  g, e, t, getDefaultLogger("MetricLearning", lvl));
206  }),
207  py::arg("graphConstructor"), py::arg("edgeClassifiers"),
208  py::arg("trackBuilder"), py::arg("level"))
209  .def("run", &ExaTrkXPipeline::run, py::arg("features"),
210  py::arg("spacepoints"), py::arg("deviceHint") = -1,
211  py::arg("hook") = Acts::ExaTrkXHook{},
212  py::arg("timing") = nullptr);
213  }
214 
216  ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
217  inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
218  outputProtoTracks, geometry, magneticField, buildTightSeeds);
219 
222  "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
223  inputMeasurements, inputSourceLinks, inputInitialTrackParameters,
224  outputTracks, measurementSelectorCfg, trackingGeometry, magneticField,
225  findTracks, tag);
226 }
227 
228 } // namespace Acts::Python