22 using namespace ActsExamples;
23 using namespace Acts::UnitLiterals;
28 double m_targetPT = 0.5_GeV;
29 std::size_t m_targetSize = 3;
31 std::unique_ptr<const Acts::Logger>
m_logger;
32 std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
33 std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
38 std::size_t spacePointIndex;
47 std::size_t targetMinHits,
double targetMinPT,
49 : m_targetPT(targetMinPT),
50 m_targetSize(targetMinHits),
51 m_logger(logger.clone(
"MetricsHook")) {
53 std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>>
tracks;
55 for (
auto i = 0ul;
i < spacepoints.size(); ++
i) {
56 const auto measId = spacepoints[
i]
58 .template get<IndexSourceLink>()
61 auto [
a,
b] = measHitMap.equal_range(measId);
63 const auto& hit = *simhits.nth(
it->second);
65 tracks[hit.particleId()].push_back({
i, hit.index()});
70 std::vector<int64_t> truthGraph;
71 std::vector<int64_t> targetGraph;
73 for (
auto& [
pid, track] : tracks) {
75 std::sort(track.begin(), track.end(), [](
const auto&
a,
const auto&
b) {
76 return a.hitIndex <
b.hitIndex;
79 auto found = particles.find(
pid);
80 if (found == particles.end()) {
85 for (
auto i = 0ul;
i < track.size() - 1; ++
i) {
86 truthGraph.push_back(track[
i].spacePointIndex);
87 truthGraph.push_back(track[
i + 1].spacePointIndex);
89 if (found->transverseMomentum() > m_targetPT &&
90 track.size() >= m_targetSize) {
91 targetGraph.push_back(track[
i].spacePointIndex);
92 targetGraph.push_back(track[
i + 1].spacePointIndex);
97 m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
98 truthGraph, logger.
clone());
99 m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
100 targetGraph, logger.
clone());
103 ~ExamplesEdmHook() {}
105 void operator()(
const std::any& nodes,
const std::any&
edges)
const override {
107 (*m_truthGraphHook)(nodes,
edges);
108 ACTS_INFO(
"Metrics for target graph (pT > "
110 <<
" GeV, nHits >= " << m_targetSize <<
"):");
111 (*m_targetGraphHook)(nodes,
edges);
119 : ActsExamples::
IAlgorithm(
"TrackFindingMLBasedAlgorithm", level),
121 m_pipeline(
m_cfg.graphConstructor,
m_cfg.edgeClassifiers,
122 m_cfg.trackBuilder, logger().clone()) {
124 throw std::invalid_argument(
"Missing spacepoint input collection");
127 throw std::invalid_argument(
"Missing protoTrack output collection");
137 if(
m_cfg.sanitize ) {
138 Eigen::VectorXf dummyInput = Eigen::VectorXf::Random(3 * 15);
139 std::vector<float> dummyInputVec(dummyInput.data(),
140 dummyInput.data() + dummyInput.size());
141 std::vector<int> spacepointIDs;
142 std::iota(spacepointIDs.begin(), spacepointIDs.end(), 0);
144 runPipeline(dummyInputVec, spacepointIDs);
176 auto spacepoints = m_inputSpacePoints(ctx);
178 auto hook = std::make_unique<Acts::ExaTrkXHook>();
179 if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
180 hook = std::make_unique<ExamplesEdmHook>(
181 spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
182 m_inputParticles(ctx),
m_cfg.targetMinHits,
m_cfg.targetMinPT,
186 std::optional<ClusterContainer>
clusters;
187 if (m_inputClusters.isInitialized()) {
188 clusters = m_inputClusters(ctx);
193 const std::size_t numSpacepoints = spacepoints.size();
194 const std::size_t numFeatures = clusters ? 7 : 3;
195 ACTS_INFO(
"Received " << numSpacepoints <<
" spacepoints");
197 std::vector<float> features(numSpacepoints * numFeatures);
198 std::vector<int> spacepointIDs;
200 spacepointIDs.reserve(spacepoints.size());
202 double sumCells = 0.0;
203 double sumActivation = 0.0;
205 for (
auto i = 0ul;
i < numSpacepoints; ++
i) {
206 const auto& sp = spacepoints[
i];
209 float* featurePtr = features.data() +
i * numFeatures;
213 const auto& sl = sp.sourceLinks()[0].template get<IndexSourceLink>();
214 spacepointIDs.push_back(sl.index());
216 featurePtr[
eR] = std::hypot(sp.x(), sp.y()) /
m_cfg.rScale;
217 featurePtr[
ePhi] = std::atan2(sp.y(), sp.x()) /
m_cfg.phiScale;
218 featurePtr[
eZ] = sp.z() /
m_cfg.zScale;
221 const auto& cluster = clusters->at(sl.index());
222 const auto& chnls = cluster.channels;
224 featurePtr[
eCellCount] = cluster.channels.size() /
m_cfg.cellCountScale;
226 std::accumulate(chnls.begin(), chnls.end(), 0.0,
228 return s +
c.activation;
235 sumActivation += featurePtr[
eCellSum];
239 ACTS_DEBUG(
"Avg cell count: " << sumCells / spacepoints.size());
240 ACTS_DEBUG(
"Avg activation: " << sumActivation / sumCells);
243 const auto trackCandidates = [&]() {
244 const int deviceHint = -1;
245 std::lock_guard<std::mutex> lock(m_mutex);
249 m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
254 for (
auto [aggr,
a] :
264 ACTS_DEBUG(
"Done with pipeline, received " << trackCandidates.size()
268 std::vector<ProtoTrack> protoTracks;
269 protoTracks.reserve(trackCandidates.size());
270 for (
auto&
x : trackCandidates) {
272 std::copy(
x.begin(),
x.end(), std::back_inserter(onetrack));
273 protoTracks.push_back(
std::move(onetrack));
276 ACTS_INFO(
"Created " << protoTracks.size() <<
" proto tracks");
277 m_outputProtoTracks(ctx,
std::move(protoTracks));
283 namespace ba = boost::accumulators;
287 const auto&
t =
m_timing.graphBuildingTime;
289 << std::sqrt(ba::variance(
t)));
291 for (
const auto&
t :
m_timing.classifierTimes) {
293 << std::sqrt(ba::variance(
t)));
296 const auto&
t =
m_timing.trackBuildingTime;
298 << std::sqrt(ba::variance(
t)));