Analysis Software
Documentation for sPHENIX simulation software
OnnxRuntimeBase.cpp
Go to the documentation of this file, or view the newest version of OnnxRuntimeBase.cpp in the sPHENIX GitHub repository.
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"

#include <cassert>
#include <stdexcept>

// Parametrized constructor
Acts::OnnxRuntimeBase::OnnxRuntimeBase(Ort::Env& env, const char* modelPath) {
  // Set the ONNX runtime session options
  Ort::SessionOptions sessionOptions;
  // Set graph optimization level
  sessionOptions.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_BASIC);
  // Create the Ort session
  m_session = std::make_unique<Ort::Session>(env, modelPath, sessionOptions);
  // Default allocator
  Ort::AllocatorWithDefaultOptions allocator;

  // Get the names of the input nodes of the model
  size_t numInputNodes = m_session->GetInputCount();
  // Iterate over all input nodes and get the name
  for (size_t i = 0; i < numInputNodes; i++) {
    m_inputNodeNamesAllocated.push_back(
        m_session->GetInputNameAllocated(i, allocator));
    m_inputNodeNames.push_back(m_inputNodeNamesAllocated.back().get());

    // Get the dimensions of the input nodes
    // Assumes single input
    Ort::TypeInfo inputTypeInfo = m_session->GetInputTypeInfo(i);
    auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
    m_inputNodeDims = tensorInfo.GetShape();
  }

  // Get the names of the output nodes
  size_t numOutputNodes = m_session->GetOutputCount();
  // Iterate over all output nodes and get the name
  for (size_t i = 0; i < numOutputNodes; i++) {
    m_outputNodeNamesAllocated.push_back(
        m_session->GetOutputNameAllocated(i, allocator));
    m_outputNodeNames.push_back(m_outputNodeNamesAllocated.back().get());

    // Get the dimensions of the output nodes
    Ort::TypeInfo outputTypeInfo = m_session->GetOutputTypeInfo(i);
    auto tensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();
    m_outputNodeDims.push_back(tensorInfo.GetShape());
  }
}
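
A minimal construction sketch (not part of this file; the header include path and the model file name below are assumptions for illustration): the only prerequisites are a long-lived Ort::Env and the path to a serialized ONNX model.

#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"  // assumed location of this class's header

int main() {
  // The Ort::Env must outlive the Ort::Session created inside OnnxRuntimeBase.
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "onnx-example");
  // "model.onnx" is a placeholder path to a trained, exported network.
  Acts::OnnxRuntimeBase model(env, "model.onnx");
  return 0;
}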

// Inference function using ONNX runtime for one single entry
std::vector<float> Acts::OnnxRuntimeBase::runONNXInference(
    std::vector<float>& inputTensorValues) const {
  Acts::NetworkBatchInput vectorInput(1, inputTensorValues.size());
  for (size_t i = 0; i < inputTensorValues.size(); i++) {
    vectorInput(0, i) = inputTensorValues[i];
  }
  auto vectorOutput = runONNXInference(vectorInput);
  return vectorOutput[0];
}
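
Continuing the construction sketch above (the feature values are placeholders), a single entry is scored by passing one flat feature vector; the returned vector holds the values of the first output node for that entry.

  // Single-entry inference: one feature vector in, one output vector back.
  std::vector<float> features = {0.1f, 0.5f, -1.2f};
  std::vector<float> scores = model.runONNXInference(features);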

// Inference function using ONNX runtime
// the function assumes that the model has 1 input node and 1 output node
std::vector<std::vector<float>> Acts::OnnxRuntimeBase::runONNXInference(
    Acts::NetworkBatchInput& inputTensorValues) const {
  return runONNXInferenceMultiOutput(inputTensorValues).front();
}
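
For batched inference the input is an Acts::NetworkBatchInput with one row per entry and one column per feature. A sketch, continuing from above and assuming (as the accompanying header suggests) that NetworkBatchInput is an Eigen row-major array; the 2x3 values are illustrative only:

  // Batched inference: 2 entries with 3 features each (placeholder numbers).
  Acts::NetworkBatchInput batch(2, 3);
  batch << 0.1f, 0.5f, -1.2f,
           0.3f, -0.7f, 2.0f;
  // One output vector per row of the batch, taken from the first output node.
  std::vector<std::vector<float>> batchScores = model.runONNXInference(batch);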

// Inference function for single-input, multi-output models
std::vector<std::vector<std::vector<float>>>
Acts::OnnxRuntimeBase::runONNXInferenceMultiOutput(
    NetworkBatchInput& inputTensorValues) const {
  int batchSize = inputTensorValues.rows();
  std::vector<int64_t> inputNodeDims = m_inputNodeDims;
  std::vector<std::vector<int64_t>> outputNodeDims = m_outputNodeDims;

  // The first dim node should correspond to the batch size
  // If it is -1, it is dynamic and should be set to the input size
  if (inputNodeDims[0] == -1) {
    inputNodeDims[0] = batchSize;
  }

  bool outputDimsMatch = true;
  for (std::vector<int64_t>& nodeDim : outputNodeDims) {
    if (nodeDim[0] == -1) {
      nodeDim[0] = batchSize;
    }
    outputDimsMatch &= batchSize == 1 || nodeDim[0] == batchSize;
  }

  if (batchSize != 1 && (inputNodeDims[0] != batchSize || !outputDimsMatch)) {
    throw std::runtime_error(
        "runONNXInference: batch size doesn't match the input or output node "
        "size");
  }

  // Create input tensor object from data values
  Ort::MemoryInfo memoryInfo =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
      memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
      inputNodeDims.data(), inputNodeDims.size());
  // Double-check that inputTensor is a Tensor
  if (!inputTensor.IsTensor()) {
    throw std::runtime_error(
        "runONNXInference: conversion of input to Tensor failed. ");
  }
  // Score model on input tensors, get back output tensors
  Ort::RunOptions run_options;
  std::vector<Ort::Value> outputTensors =
      m_session->Run(run_options, m_inputNodeNames.data(), &inputTensor,
                     m_inputNodeNames.size(), m_outputNodeNames.data(),
                     m_outputNodeNames.size());

  // Double-check that outputTensors contains Tensors and that the count
  // matches that of output nodes
  if (!outputTensors[0].IsTensor() ||
      (outputTensors.size() != m_outputNodeNames.size())) {
    throw std::runtime_error(
        "runONNXInference: calculation of output failed. ");
  }

  std::vector<std::vector<std::vector<float>>> multiOutput;

  for (size_t i_out = 0; i_out < outputTensors.size(); i_out++) {
    // Get pointer to output tensor float values
    float* outputTensor = outputTensors.at(i_out).GetTensorMutableData<float>();
    // Get the output values
    std::vector<std::vector<float>> outputTensorValues(
        batchSize, std::vector<float>(outputNodeDims.at(i_out)[1], -1));
    for (int i = 0; i < outputNodeDims.at(i_out)[0]; i++) {
      for (int j = 0; j < ((outputNodeDims.at(i_out).size() > 1)
                               ? outputNodeDims.at(i_out)[1]
                               : 1);
           j++) {
        outputTensorValues[i][j] =
            outputTensor[i * outputNodeDims.at(i_out)[1] + j];
      }
    }
    multiOutput.push_back(std::move(outputTensorValues));
  }
  return multiOutput;
}
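
For models with several output nodes, runONNXInferenceMultiOutput returns one block per output node, indexed as [output node][batch entry][value]. A final sketch, reusing the batch built above (the indices picked out are illustrative):

  // Multi-output inference: the outer index selects the output node.
  std::vector<std::vector<std::vector<float>>> multi =
      model.runONNXInferenceMultiOutput(batch);
  // First output node, first batch entry, first value.
  float firstValue = multi[0][0][0];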