#include <torch/script.h>
#include <torch/torch.h>

#include <any>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <vector>

using namespace torch::indexing;
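
// Metric-learning graph construction: a TorchScript model embeds each
// node's features into a learned metric space, and candidate edges are
// then built between nearby points in that space.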
TorchMetricLearning::TorchMetricLearning(
    const Config &cfg, std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)), m_cfg(cfg) {
  c10::InferenceMode guard(true);
  m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
                                    << TORCH_VERSION_MINOR << "."
                                    << TORCH_VERSION_PATCH);
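
  // Built with CUDA support but no device visible at runtime: m_deviceType
  // already fell back to kCPU above, so we only inform the user here.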
#ifndef ACTS_EXATRKX_CPUONLY
  if (not torch::cuda::is_available()) {
    ACTS_INFO("CUDA not available, falling back to CPU");
  }
#endif

  try {
    m_model = std::make_unique<torch::jit::Module>();
    *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_deviceType);
    m_model->eval();
  } catch (const c10::Error &e) {
    throw std::invalid_argument("Failed to load models: " + e.msg());
  }
}
std::tuple<std::any, std::any> TorchMetricLearning::operator()(
    std::vector<float> &inputValues, std::size_t numNodes, int deviceHint) {
  c10::InferenceMode guard(true);
  const torch::Device device(m_deviceType, deviceHint);

  const int64_t numAllFeatures = inputValues.size() / numNodes;
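
  // inputValues holds numNodes rows of features back to back, so the
  // per-node feature count is the total size divided by numNodes. Log the
  // first row for debugging: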
  std::stringstream ss;
  for (int i = 0; i < numAllFeatures; ++i) {
    ss << inputValues[i] << " ";
  }
  ACTS_VERBOSE("First node features: " << ss.str());

  printCudaMemInfo(logger());
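
  // Assumed conversion (the original helper is elided): view the flat
  // row-major vector as a (numNodes x numAllFeatures) float tensor without
  // copying. from_blob does not take ownership of the data, so the
  // clone()/to() below always leaves us with a tensor that owns its memory.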
  auto inputTensor = torch::from_blob(
      inputValues.data(), {static_cast<int64_t>(numNodes), numAllFeatures},
      torch::TensorOptions().dtype(torch::kFloat32));

  if (inputTensor.options().device() == device) {
    inputTensor = inputTensor.clone();
  } else {
    inputTensor = inputTensor.to(device);
  }
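
  // Embedding stage: feed (at most) the first numFeatures columns of each
  // node through the TorchScript model to get metric-space coordinates.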
  if (m_cfg.numFeatures > numAllFeatures) {
    throw std::runtime_error("requested more features than available");
  }

  // Work on a clone of the module so the stored model stays untouched when
  // moving to the target device
  auto model = m_model->clone();
  model.to(device);

  std::vector<torch::jit::IValue> inputTensors;
  inputTensors.push_back(
      m_cfg.numFeatures == numAllFeatures
          ? inputTensor
          : inputTensor.index({Slice{}, Slice{None, m_cfg.numFeatures}}));
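  // Note: index({Slice{}, Slice{None, m_cfg.numFeatures}}) is the C++
  // equivalent of the Python slicing tensor[:, :numFeatures].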

  ACTS_DEBUG("Embedding input tensor shape: "
             << inputTensors[0].toTensor().size(0) << ", "
             << inputTensors[0].toTensor().size(1));
  auto output = model.forward(inputTensors).toTensor();

  ACTS_VERBOSE("Embedding of the first node:\n" << output.slice(0, 0, 1));
  printCudaMemInfo(logger());
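
  // Edge building: search for neighbours within radius rVal, keeping at
  // most knnVal of them per node (the exact semantics of the elided
  // buildEdges helper are assumed here).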
  auto edgeList = detail::buildEdges(output, m_cfg.rVal, m_cfg.knnVal,
                                     m_cfg.shuffleDirections);

  ACTS_VERBOSE("Shape of built edges: (" << edgeList.size(0) << ", "
                                         << edgeList.size(1) << ")");
  ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5));
  printCudaMemInfo(logger());

  return {std::move(inputTensor), std::move(edgeList)};
}