11 #include <torch/script.h>
12 #include <torch/torch.h>
16 using namespace torch::indexing;
// Constructor: selects the torch device (CUDA when available at runtime) and
// prepares the TorchScript model holder; surfaces model-load failures to the
// caller as std::invalid_argument.
// NOTE(review): this excerpt is elided — the member-initializer list and the
// try/torch::jit::load lines are not visible here; comments are placed only at
// visible statement boundaries.
20 TorchEdgeClassifier::TorchEdgeClassifier(
const Config&
cfg,
21 std::unique_ptr<const Logger> _logger)
// Inference-only scope: disables autograd/version-counter bookkeeping for the
// rest of the constructor.
23 c10::InferenceMode guard(
true);
// Prefer the GPU when the runtime reports one; otherwise fall back to CPU.
24 m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
25 ACTS_DEBUG(
"Using torch version " << TORCH_VERSION_MAJOR <<
"."
26 << TORCH_VERSION_MINOR <<
"."
27 << TORCH_VERSION_PATCH);
// Only emit the CUDA-fallback notice when the build was not CPU-only anyway.
28 #ifndef ACTS_EXATRKX_CPUONLY
29 if (not torch::cuda::is_available()) {
30 ACTS_INFO(
"CUDA not available, falling back to CPU");
// Allocate the model holder up front; the actual model load happens in elided
// lines (inside the try whose catch handler follows below).
35 m_model = std::make_unique<torch::jit::Module>();
38 }
// Translate torch-side load errors into a caller-visible invalid_argument.
catch (
const c10::Error&
e) {
39 throw std::invalid_argument(
"Failed to load models: " + e.msg());
// operator(): scores candidate edges with the TorchScript classifier (in
// chunks, optionally duplicating edges as undirected) and keeps only edges
// passing the score cut.
// NOTE(review): this excerpt is elided — the signature's first line, the
// device/feature-count setup, the score-cut mask construction, and the opening
// of the return statement are not visible here.
46 std::any inputNodes, std::any inputEdges,
int deviceHint) {
// Inference-only scope: no autograd tracking while scoring.
48 c10::InferenceMode guard(
true);
// Unpack the type-erased tensors and move them to the selected device
// (`device` is established in elided lines).
51 auto nodes = std::any_cast<torch::Tensor>(inputNodes).to(device);
52 auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(device);
// Guard (condition elided): cannot use more node features than provided.
// NOTE(review): message typo "then" should be "than" — runtime string, left
// unchanged here.
55 throw std::runtime_error(
"requested more features then available");
// Per-chunk classifier outputs, concatenated after the loop.
58 std::vector<at::Tensor>
results;
// (start of this expression is elided) When configured as undirected, append
// the flipped edge list so both directions are scored.
62 m_cfg.
undirected ? torch::cat({edgeList, edgeList.flip(0)}, 1) : edgeList;
64 std::vector<torch::jit::IValue> inputTensors(2);
// Split the edge indices into m_cfg.nChunks pieces to bound peak memory per
// forward pass.
69 const auto chunks = at::chunk(at::arange(edgeListTmp.size(1)),
m_cfg.
nChunks);
70 for (
const auto& chunk : chunks) {
72 inputTensors[1] = edgeListTmp.index({Slice(), chunk});
// Run the model on this chunk and post-process in place: drop the trailing
// dim and map logits to [0, 1] scores.
74 results.push_back(
m_model->forward(inputTensors).toTensor());
75 results.back().squeeze_();
76 results.back().sigmoid_();
// All per-chunk scores, in input edge order.
79 auto output = torch::cat(results);
88 printCudaMemInfo(
logger());
// Keep only edges passing the score cut (`mask` built in elided lines);
// downstream consumers expect int64 edge indices.
91 torch::Tensor edgesAfterCut = edgeList.index({Slice(), mask});
92 edgesAfterCut = edgesAfterCut.to(torch::kInt64);
94 ACTS_VERBOSE(
"Size after score cut: " << edgesAfterCut.size(1));
95 printCudaMemInfo(
logger());
98 output.masked_select(mask)};