11 #ifdef ACTS_PLUGIN_ONNX
48 using namespace Acts::UnitLiterals;
49 using namespace ActsExamples;
50 using namespace std::filesystem;
51 using namespace std::placeholders;
54 using namespace ActsExamples;
55 using boost::program_options::bool_switch;
// Obtain the option registrar from `desc` and register the two boolean
// switches that choose how the CKF is steered.
57 auto opt = desc.add_options();
// Switch: use track parameters smeared from truth particles.
58 opt(
"ckf-truth-smeared-seeds", bool_switch(),
59 "Use track parameters smeared from truth particles for steering CKF");
// Switch: use track parameters estimated from truth tracks.
60 opt(
"ckf-truth-estimated-seeds", bool_switch(),
61 "Use track parameters estimated from truth tracks for steering CKF");
// Parameter list of the entry point (its name and return type are on a line
// outside this excerpt): the command-line arguments and a shared pointer to
// the detector implementation.
65 int argc,
char* argv[],
66 const std::shared_ptr<ActsExamples::IBaseDetector>&
detector) {
// Tail of a call whose beginning is not visible here; it passes the Csv and
// DirectoryOnly output-format flags combined with bitwise OR.
75 OutputFormat::Csv | OutputFormat::DirectoryOnly);
// Hand `desc` to the detector's addOptions.
76 detector->addOptions(desc);
// Shared RandomNumbers instance; its constructor arguments are outside this
// excerpt.
97 auto rnd = std::make_shared<ActsExamples::RandomNumbers>(
// Reads the "ckf-truth-estimated-seeds" switch as a bool. The variable being
// initialised is on a line not shown (presumably truthEstimatedSeeded, which
// is tested further down -- confirm).
101 vm[
"ckf-truth-estimated-seeds"].template as<bool>();
// Iterate over the elements of geometry.second by const reference.
107 for (
const auto& cdr : geometry.second) {
// Last argument of a call started outside this excerpt.
120 simHitReaderCfg.outputSimHits);
// The seed selector reads the particles produced by the particle reader.
127 particleSelectorCfg.
inputParticles = particleReader.outputParticles;
// Right-hand side of an assignment whose target is not visible.
129 digiCfg.outputMeasurementParticlesMap;
// Selector threshold ptMin set to 500 MeV.
131 particleSelectorCfg.
ptMin = 500_MeV;
// TruthSeedSelector built from the config above and the log level; the
// enclosing call is not visible.
134 std::make_shared<TruthSeedSelector>(particleSelectorCfg,
logLevel));
// Branch taken when truthSmearedSeeded is set.
142 if (truthSmearedSeeded) {
// Initialiser of particleSmearingCfg is outside this excerpt.
144 auto particleSmearingCfg =
// The smearing output is recorded in outputTrackParameters, which is later
// assigned to trackFindingCfg.inputInitialTrackParameters.
146 outputTrackParameters = particleSmearingCfg.outputTrackParameters;
// Branch taken when truthEstimatedSeeded is set.
160 if (truthEstimatedSeeded) {
// Right-hand side of an assignment whose target is not visible.
165 digiCfg.outputMeasurementParticlesMap;
// TruthTrackFinder built from trackFinderCfg and the log level.
168 std::make_shared<TruthTrackFinder>(trackFinderCfg,
logLevel));
// SeedingAlgorithm built from seedingCfg and the log level.
218 std::make_shared<SeedingAlgorithm>(seedingCfg,
logLevel));
// SeedsToPrototracks built from seedsToPrototrackCfg and the log level.
224 std::make_shared<SeedsToPrototracks>(seedsToPrototrackCfg,
logLevel));
// Right-hand side of an assignment whose target is not visible.
236 digiCfg.outputMeasurementParticlesMap;
// Output file of the track-finder performance writer, placed under outputDir.
237 tfPerfCfg.
filePath = outputDir +
"/performance_seeding_trees.root";
239 std::make_shared<TrackFinderPerformanceWriter>(tfPerfCfg,
logLevel));
// Six initial sigmas given with unit literals (um, um, degree, degree,
// 1/GeV, s).
// NOTE(review): the last entry is 1400 in seconds -- confirm this unit and
// magnitude are intended.
248 paramsEstimationCfg.
initialSigmas = {25._um, 100._um, 0.02_degree,
249 0.02_degree, 0.1 / 1._GeV, 1400._s};
// Reads "ckf-initial-variance-inflation" as Options::Reals<6>; the assignment
// target is not visible.
251 vm[
"ckf-initial-variance-inflation"].template as<Options::Reals<6>>();
// Adds a TrackParamsEstimationAlgorithm to the sequencer; its constructor
// arguments are outside this excerpt.
253 sequencer.
addAlgorithm(std::make_shared<TrackParamsEstimationAlgorithm>(
// Track finding inputs: measurements and source links come from the
// digitization config, initial parameters from outputTrackParameters.
263 trackFindingCfg.inputMeasurements = digiCfg.outputMeasurements;
264 trackFindingCfg.inputSourceLinks = digiCfg.outputSourceLinks;
265 trackFindingCfg.inputInitialTrackParameters = outputTrackParameters;
// Output key "tracks"; it is reused below as tracksToTrajCfg.inputTracks.
266 trackFindingCfg.outputTracks =
"tracks";
// Enable the computeSharedHits flag.
267 trackFindingCfg.computeSharedHits =
true;
// Finder function created by the static factory; its arguments are outside
// this excerpt.
268 trackFindingCfg.findTracks = TrackFindingAlgorithm::makeTrackFinderFunction(
// TrackFindingAlgorithm built from the config above and the log level.
272 std::make_shared<TrackFindingAlgorithm>(trackFindingCfg,
logLevel));
// Tracks-to-trajectories step: reads the track finder's output and writes
// under the key "trajectories".
275 tracksToTrajCfg.
inputTracks = trackFindingCfg.outputTracks;
276 tracksToTrajCfg.outputTrajectories =
"trajectories";
278 (std::make_shared<TracksToTrajectories>(tracksToTrajCfg,
logLevel)));
// Track-states writer config: particles from the particle reader and sim hits
// from the sim-hit reader.
287 trackStatesWriter.
inputParticles = particleReader.outputParticles;
288 trackStatesWriter.
inputSimHits = simHitReaderCfg.outputSimHits;
// Right-hand sides of two assignments whose targets are not visible.
290 digiCfg.outputMeasurementParticlesMap;
292 digiCfg.outputMeasurementSimHitsMap;
// Output file under outputDir, tree name "trackstates".
293 trackStatesWriter.
filePath = outputDir +
"/trackstates_ckf.root";
294 trackStatesWriter.
treeName =
"trackstates";
// Adds a RootTrajectoryStatesWriter to the sequencer; its constructor
// arguments are outside this excerpt.
295 sequencer.
addWriter(std::make_shared<RootTrajectoryStatesWriter>(
// Track-summary writer config: particles from the particle reader.
305 trackSummaryWriter.
inputParticles = particleReader.outputParticles;
// Right-hand side of an assignment whose target is not visible.
307 digiCfg.outputMeasurementParticlesMap;
// Output file under outputDir, tree name "tracksummary".
308 trackSummaryWriter.
filePath = outputDir +
"/tracksummary_ckf.root";
309 trackSummaryWriter.
treeName =
"tracksummary";
// Adds a RootTrajectorySummaryWriter to the sequencer; its constructor
// arguments are outside this excerpt.
310 sequencer.
addWriter(std::make_shared<RootTrajectorySummaryWriter>(
// Right-hand side of an assignment whose target is not visible.
318 digiCfg.outputMeasurementParticlesMap;
// Output file of the performance writer, placed under outputDir.
319 perfWriterCfg.
filePath = outputDir +
"/performance_ckf.root";
// Guard on ACTS_PLUGIN_ONNX; the matching #endif is outside this excerpt.
320 #ifdef ACTS_PLUGIN_ONNX
// Directory of this source file, taken from the compile-time __FILE__ path.
323 path currentFilePath(__FILE__);
324 path parentPath = currentFilePath.parent_path();
// Canonical native path of MLAmbiguityResolutionDemo.onnx in that directory;
// the assignment target is not visible.
// NOTE(review): presumably std::filesystem::canonical, which needs the file
// to exist at run time at the location recorded at compile time -- confirm
// this holds for installed builds.
326 canonical(parentPath /
"MLAmbiguityResolutionDemo.onnx").native();
// Threshold value forwarded to the call fragment below.
328 double decisionThreshProb = 0.5;
// ONNX Runtime environment with warning-level logging and the identifier
// "MLTrackClassifier".
330 Ort::Env
env(ORT_LOGGING_LEVEL_WARNING,
"MLTrackClassifier");
// Tail of a call whose beginning is not visible: passes placeholder _1 and
// the threshold.
334 std::placeholders::_1, decisionThreshProb);
// CKFPerformanceWriter built from perfWriterCfg and the log level.
337 std::make_shared<CKFPerformanceWriter>(perfWriterCfg,
logLevel));
// Branch taken only when the "output-csv" option reads as true.
339 if (vm[
"output-csv"].
template as<bool>()) {
// Right-hand side of an assignment whose target is not visible.
345 digiCfg.outputMeasurementParticlesMap;
// Adds a CsvMultiTrajectoryWriter to the sequencer; its constructor arguments
// are outside this excerpt.
346 sequencer.
addWriter(std::make_shared<CsvMultiTrajectoryWriter>(
// Run the sequencer and return its result to the caller.
350 return sequencer.
run();