Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxlib.cc
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file onnxlib.cc
#include "onnxlib.h"

#include <iostream>
#include <memory>
#include <stdexcept>
4 
5 // --------------------------------------------------
6 Ort::Session *onnxSession(std::string &modelfile)
7 {
8  Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "fit");
9  Ort::SessionOptions sessionOptions;
10  sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
11 
12  return new Ort::Session(env, modelfile.c_str(), sessionOptions);
13 }
14 
15 std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nsamp, int Nreturn)
16 {
17  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
18 
19  Ort::AllocatorWithDefaultOptions allocator;
20 
21  std::vector<Ort::Value> inputTensors, outputTensors;
22 
23  std::vector<int64_t> inputDimsN = {N, Nsamp};
24  std::vector<int64_t> outputDimsN = {N, Nreturn};
25 
26  std::vector<float> outputTensorValuesN(N * Nreturn);
27 
28  inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), N * Nsamp, inputDimsN.data(), inputDimsN.size()));
29  outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValuesN.data(), N * Nreturn, outputDimsN.data(), outputDimsN.size()));
30 
31  std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
32  std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
33 
34  session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
35 
36  return outputTensorValuesN;
37 }