24 #include <pybind11/functional.h>
25 #include <pybind11/pybind11.h>
26 #include <pybind11/stl.h>
28 namespace py = pybind11;
30 using namespace ActsExamples;
33 namespace Acts::Python {
36 auto [
m, mex] = ctx.
get(
"main",
"examples");
40 auto c = py::class_<C, std::shared_ptr<C>>(mex,
"GraphConstructionBase");
44 auto c = py::class_<C, std::shared_ptr<C>>(mex,
"EdgeClassificationBase");
48 auto c = py::class_<C, std::shared_ptr<C>>(mex,
"TrackBuildingBase");
51 #ifdef ACTS_EXATRKX_TORCH_BACKEND
57 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
58 mex,
"TorchMetricLearning")
60 return std::make_shared<Alg>(
63 py::arg(
"config"), py::arg(
"level"))
66 auto c = py::class_<Config>(alg,
"Config").def(py::init<>());
80 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
81 mex,
"TorchEdgeClassifier")
83 return std::make_shared<Alg>(
86 py::arg(
"config"), py::arg(
"level"))
89 auto c = py::class_<Config>(alg,
"Config").def(py::init<>());
101 auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
102 mex,
"BoostTrackBuilding")
104 return std::make_shared<Alg>(
111 #ifdef ACTS_EXATRKX_ONNX_BACKEND
117 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
118 mex,
"OnnxMetricLearning")
120 return std::make_shared<Alg>(
123 py::arg(
"config"), py::arg(
"level"))
126 auto c = py::class_<Config>(alg,
"Config").def(py::init<>());
140 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
141 mex,
"OnnxEdgeClassifier")
143 return std::make_shared<Alg>(
146 py::arg(
"config"), py::arg(
"level"))
149 auto c = py::class_<Config>(alg,
"Config").def(py::init<>());
158 auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
159 mex,
"CugraphTrackBuilding")
161 return std::make_shared<Alg>(
170 "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
172 outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder,
173 rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale,
174 clusterYScale, targetMinHits, targetMinPT);
178 py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
185 auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
186 mex,
"TorchTruthGraphMetricsHook")
189 return std::make_shared<Class>(
198 py::class_<Class, std::shared_ptr<Class>>(mex,
"ExaTrkXPipeline")
200 [](std::shared_ptr<GraphConstructionBase>
g,
201 std::vector<std::shared_ptr<EdgeClassificationBase>>
e,
202 std::shared_ptr<TrackBuildingBase>
t,
204 return std::make_shared<Class>(
207 py::arg(
"graphConstructor"), py::arg(
"edgeClassifiers"),
208 py::arg(
"trackBuilder"), py::arg(
"level"))
210 py::arg(
"spacepoints"), py::arg(
"deviceHint") = -1,
212 py::arg(
"timing") =
nullptr);
217 inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
222 "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
223 inputMeasurements, inputSourceLinks, inputInitialTrackParameters,