Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchEdgeClassifier.cpp
Go to the documentation of this file, or view the newest version of TorchEdgeClassifier.cpp in the sPHENIX GitHub repository.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
10 
11 #include <torch/script.h>
12 #include <torch/torch.h>
13 
14 #include "printCudaMemInfo.hpp"
15 
16 using namespace torch::indexing;
17 
18 namespace Acts {
19 
20 TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
21  std::unique_ptr<const Logger> _logger)
22  : m_logger(std::move(_logger)), m_cfg(cfg) {
23  c10::InferenceMode guard(true);
24  m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
25  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
26  << TORCH_VERSION_MINOR << "."
27  << TORCH_VERSION_PATCH);
28 #ifndef ACTS_EXATRKX_CPUONLY
29  if (not torch::cuda::is_available()) {
30  ACTS_INFO("CUDA not available, falling back to CPU");
31  }
32 #endif
33 
34  try {
35  m_model = std::make_unique<torch::jit::Module>();
36  *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_deviceType);
37  m_model->eval();
38  } catch (const c10::Error& e) {
39  throw std::invalid_argument("Failed to load models: " + e.msg());
40  }
41 }
42 
44 
45 std::tuple<std::any, std::any, std::any> TorchEdgeClassifier::operator()(
46  std::any inputNodes, std::any inputEdges, int deviceHint) {
47  ACTS_DEBUG("Start edge classification");
48  c10::InferenceMode guard(true);
49  const torch::Device device(m_deviceType, deviceHint);
50 
51  auto nodes = std::any_cast<torch::Tensor>(inputNodes).to(device);
52  auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(device);
53 
54  if (m_cfg.numFeatures > nodes.size(1)) {
55  throw std::runtime_error("requested more features then available");
56  }
57 
58  std::vector<at::Tensor> results;
59  results.reserve(m_cfg.nChunks);
60 
61  auto edgeListTmp =
62  m_cfg.undirected ? torch::cat({edgeList, edgeList.flip(0)}, 1) : edgeList;
63 
64  std::vector<torch::jit::IValue> inputTensors(2);
65  inputTensors[0] = m_cfg.numFeatures < nodes.size(1)
66  ? nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}})
67  : nodes;
68 
69  const auto chunks = at::chunk(at::arange(edgeListTmp.size(1)), m_cfg.nChunks);
70  for (const auto& chunk : chunks) {
71  ACTS_VERBOSE("Process chunk");
72  inputTensors[1] = edgeListTmp.index({Slice(), chunk});
73 
74  results.push_back(m_model->forward(inputTensors).toTensor());
75  results.back().squeeze_();
76  results.back().sigmoid_();
77  }
78 
79  auto output = torch::cat(results);
80 
81  if (m_cfg.undirected) {
82  output = output.index({Slice(None, output.size(0) / 2)});
83  }
84 
85  ACTS_VERBOSE("Size after classifier: " << output.size(0));
86  ACTS_VERBOSE("Slice of classified output:\n"
87  << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
88  printCudaMemInfo(logger());
89 
90  torch::Tensor mask = output > m_cfg.cut;
91  torch::Tensor edgesAfterCut = edgeList.index({Slice(), mask});
92  edgesAfterCut = edgesAfterCut.to(torch::kInt64);
93 
94  ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
95  printCudaMemInfo(logger());
96 
97  return {std::move(nodes), std::move(edgesAfterCut),
98  output.masked_select(mask)};
99 }
100 
101 } // namespace Acts