Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchTruthGraphMetricsHook.cpp
Go to the documentation of this file, or view the newest version of TorchTruthGraphMetricsHook.cpp on the sPHENIX GitHub.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
10 
12 
13 #include <torch/torch.h>
14 
15 namespace {
16 
17 auto cantorize(std::vector<int64_t> edgeIndex, const Acts::Logger& logger) {
18  // Use cantor pairing to store truth graph, so we can easily use set
19  // operations to compute efficiency and purity
20  std::vector<Acts::detail::CantorEdge<int64_t>> cantorEdgeIndex;
21  cantorEdgeIndex.reserve(edgeIndex.size() / 2);
22  for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
23  cantorEdgeIndex.emplace_back(*it, *std::next(it));
24  }
25 
26  std::sort(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
27 
28  auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
29  if (new_end != cantorEdgeIndex.end()) {
30  ACTS_WARNING("Graph not unique ("
31  << std::distance(new_end, cantorEdgeIndex.end())
32  << " duplicates)");
33  cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
34  }
35 
36  return cantorEdgeIndex;
37 }
38 
39 } // namespace
40 
42  const std::vector<int64_t>& truthGraph,
43  std::unique_ptr<const Acts::Logger> l)
44  : m_logger(std::move(l)) {
45  m_truthGraphCantor = cantorize(truthGraph, logger());
46 }
47 
49  const std::any& edges) const {
50  // We need to transpose the edges here for the right memory layout
51  const auto edgeIndex = Acts::detail::tensor2DToVector<int64_t>(
52  std::any_cast<torch::Tensor>(edges).t());
53 
54  auto predGraphCantor = cantorize(edgeIndex, logger());
55 
56  // Calculate intersection
57  std::vector<Acts::detail::CantorEdge<int64_t>> intersection;
58  intersection.reserve(
59  std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
60 
61  std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
62  m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
63  std::back_inserter(intersection));
64 
65  ACTS_DEBUG("Intersection size " << intersection.size());
66  const float intersectionSizeFloat = intersection.size();
67  const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
68  const float pur = intersectionSizeFloat / predGraphCantor.size();
69 
70  ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
71 }