11 #include <onnxruntime_cxx_api.h>
12 #include <torch/script.h>
16 using namespace torch::indexing;
20 OnnxEdgeClassifier::OnnxEdgeClassifier(
const Config &
cfg,
21 std::unique_ptr<const Logger>
logger)
23 m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
24 "ExaTrkX - edge classifier");
26 Ort::SessionOptions session_options;
27 session_options.SetIntraOpNumThreads(1);
28 session_options.SetGraphOptimizationLevel(
29 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
34 Ort::AllocatorWithDefaultOptions allocator;
47 std::any inputNodes, std::any inputEdges,
int) {
48 Ort::AllocatorWithDefaultOptions allocator;
49 auto memoryInfo = Ort::MemoryInfo::CreateCpu(
50 OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
52 auto eInputTensor = std::any_cast<std::shared_ptr<Ort::Value>>(inputNodes);
53 auto edgeList = std::any_cast<std::vector<int64_t>>(inputEdges);
54 const int numEdges = edgeList.size() / 2;
58 std::vector<Ort::Value> fInputTensor;
59 fInputTensor.push_back(
std::move(*eInputTensor));
60 std::vector<int64_t> fEdgeShape{2, numEdges};
61 fInputTensor.push_back(Ort::Value::CreateTensor<int64_t>(
62 memoryInfo, edgeList.data(), edgeList.size(), fEdgeShape.data(),
67 std::vector<float> fOutputData(numEdges);
69 auto outputDims =
m_model->GetOutputTypeInfo(0)
70 .GetTensorTypeAndShapeInfo()
71 .GetDimensionsCount();
72 using Shape = std::vector<int64_t>;
73 Shape fOutputShape = outputDims == 2 ? Shape{numEdges, 1} : Shape{numEdges};
74 std::vector<Ort::Value> fOutputTensor;
75 fOutputTensor.push_back(Ort::Value::CreateTensor<float>(
76 memoryInfo, fOutputData.data(), fOutputData.size(), fOutputShape.data(),
77 fOutputShape.size()));
81 ACTS_DEBUG(
"Get scores for " << numEdges <<
" edges.");
82 torch::Tensor edgeListCTen = torch::tensor(edgeList, {torch::kInt64});
83 edgeListCTen = edgeListCTen.reshape({2, numEdges});
85 torch::Tensor fOutputCTen = torch::tensor(fOutputData, {torch::kFloat32});
86 fOutputCTen = fOutputCTen.sigmoid();
88 torch::Tensor filterMask = fOutputCTen >
m_cfg.
cut;
89 torch::Tensor edgesAfterFCTen = edgeListCTen.index({Slice(), filterMask});
91 std::vector<int64_t> edgesAfterFiltering;
92 std::copy(edgesAfterFCTen.data_ptr<int64_t>(),
93 edgesAfterFCTen.data_ptr<int64_t>() + edgesAfterFCTen.numel(),
94 std::back_inserter(edgesAfterFiltering));
96 int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2;
97 ACTS_DEBUG(
"Finished edge classification, after cut: " << numEdgesAfterF
100 return {std::make_shared<Ort::Value>(
std::move(fInputTensor[0])),
101 edgesAfterFiltering, fOutputCTen};