Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
AlgebraHelpers.hpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file AlgebraHelpers.hpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2016-2023 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
12 
#include <bitset>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>

#include "Eigen/Dense"
17 
18 namespace Acts {
19 
/// Convert a bitset into a fixed-size matrix.
///
/// The matrix is filled in storage order (the order of `result.data()`).
/// The bitset's most significant bit becomes the first stored element and
/// bit 0 becomes the last one. Each bit is converted to the matrix scalar,
/// so every element ends up as 0 or 1.
///
/// @tparam MatrixType fixed-size matrix type to produce; dynamic sizes are
///                    rejected at compile time
/// @param bs bitset holding exactly rows * cols bits
/// @return the matrix populated from @p bs
template <typename MatrixType>
MatrixType bitsetToMatrix(const std::bitset<MatrixType::RowsAtCompileTime *
                                            MatrixType::ColsAtCompileTime>
                              bs) {
  constexpr int nRows = MatrixType::RowsAtCompileTime;
  constexpr int nCols = MatrixType::ColsAtCompileTime;
  static_assert(nRows != -1 && nCols != -1,
                "bitsetToMatrix does not support dynamic matrices");

  MatrixType result;
  auto* cursor = result.data();
  // Walk the bits from most to least significant while advancing through
  // the matrix storage from front to back.
  for (int bit = nRows * nCols - 1; bit >= 0; --bit) {
    *cursor++ = bs[bit];
  }
  return result;
}
44 
51 template <typename Derived>
52 auto matrixToBitset(const Eigen::PlainObjectBase<Derived>& m) {
53  using MatrixType = Eigen::PlainObjectBase<Derived>;
54  constexpr size_t rows = MatrixType::RowsAtCompileTime;
55  constexpr size_t cols = MatrixType::ColsAtCompileTime;
56 
57  std::bitset<rows * cols> res;
58 
59  auto* p = m.data();
60  for (size_t i = 0; i < rows * cols; i++) {
61  res[rows * cols - 1 - i] = static_cast<bool>(p[i]);
62  }
63 
64  return res;
65 }
66 
77 template <typename A, typename B>
79  const A& a, const B& b) {
80  // Extract the sizes of the matrix types that we receive as template
81  // parameters.
82  constexpr int M = A::RowsAtCompileTime;
83  constexpr int N = A::ColsAtCompileTime;
84  constexpr int P = B::ColsAtCompileTime;
85 
86  // Ensure that the second dimension of our first matrix equals the first
87  // dimension of the second matrix, otherwise we cannot multiply.
88  static_assert(N == B::RowsAtCompileTime);
89 
90  if constexpr (M <= 4 && N <= 4 && P <= 4) {
91  // In cases where the matrices being multiplied are small, we can rely on
92  // Eigen do to a good job, and we don't really need to do any blocking.
93  return a * b;
94  } else {
95  // Here, we want to calculate the expression: C = AB, Eigen, natively,
96  // doesn't do a great job at this if the matrices A and B are large
97  // (roughly M >= 8, N >= 8, or P >= 8), and applies a slow GEMM operation.
98  // We apply a blocked matrix multiplication operation to decompose the
99  // multiplication into smaller operations, relying on the fact that:
100  //
101  // ┌ ┐ ┌ ┐ ┌ ┐
102  // │ C₁₁ C₁₂ │ = │ A₁₁ A₁₂ │ │ B₁₁ B₁₂ │
103  // │ C₂₁ C₂₂ │ = │ A₂₁ A₂₂ │ │ B₂₁ B₂₂ │
104  // └ ┘ └ ┘ └ ┘
105  //
106  // where:
107  //
108  // C₁₁ = A₁₁ * B₁₁ + A₁₂ * B₂₁
109  // C₁₂ = A₁₁ * B₁₂ + A₁₂ * B₂₂
110  // C₂₁ = A₂₁ * B₁₁ + A₂₂ * B₂₁
111  // C₂₂ = A₂₁ * B₁₂ + A₂₂ * B₂₂
112  //
113  // The sizes of these submatrices are roughly half (in each dimension) that
114  // of the parent matrix. If the size of the parent matrix is even, we can
115  // divide it exactly, If the size of the parent matrix is odd, then some
116  // of the submatrices will be one larger than the others. In general, for
117  // any matrix Q, the sizes of the submatrices are (where / denotes integer
118  // division):
119  //
120  // Q₁₁ : M / 2 × P / 2
121  // Q₁₂ : M / 2 × (P + 1) / 2
122  // Q₂₁ : (M + 1) / 2 × P / 2
123  // Q₂₂ : (M + 1) / 2 × (P + 1) / 2
124  //
125  // See https://csapp.cs.cmu.edu/public/waside/waside-blocking.pdf for a
126  // more in-depth explanation of blocked matrix multiplication.
127  constexpr int M1 = M / 2;
128  constexpr int M2 = (M + 1) / 2;
129  constexpr int N1 = N / 2;
130  constexpr int N2 = (N + 1) / 2;
131  constexpr int P1 = P / 2;
132  constexpr int P2 = (P + 1) / 2;
133 
134  // Construct the end result in this matrix, which destroys a few of Eigen's
135  // built-in optimization techniques, but sadly this is necessary.
137 
138  // C₁₁ = A₁₁ * B₁₁ + A₁₂ * B₂₁
139  r.template topLeftCorner<M1, P1>().noalias() =
140  a.template topLeftCorner<M1, N1>() *
141  b.template topLeftCorner<N1, P1>() +
142  a.template topRightCorner<M1, N2>() *
143  b.template bottomLeftCorner<N2, P1>();
144 
145  // C₁₂ = A₁₁ * B₁₂ + A₁₂ * B₂₂
146  r.template topRightCorner<M1, P2>().noalias() =
147  a.template topLeftCorner<M1, N1>() *
148  b.template topRightCorner<N1, P2>() +
149  a.template topRightCorner<M1, N2>() *
150  b.template bottomRightCorner<N2, P2>();
151 
152  // C₂₁ = A₂₁ * B₁₁ + A₂₂ * B₂₁
153  r.template bottomLeftCorner<M2, P1>().noalias() =
154  a.template bottomLeftCorner<M2, N1>() *
155  b.template topLeftCorner<N1, P1>() +
156  a.template bottomRightCorner<M2, N2>() *
157  b.template bottomLeftCorner<N2, P1>();
158 
159  // C₂₂ = A₂₁ * B₁₂ + A₂₂ * B₂₂
160  r.template bottomRightCorner<M2, P2>().noalias() =
161  a.template bottomLeftCorner<M2, N1>() *
162  b.template topRightCorner<N1, P2>() +
163  a.template bottomRightCorner<M2, N2>() *
164  b.template bottomRightCorner<N2, P2>();
165 
166  return r;
167  }
168 }
169 
177 
/// Invert a matrix, reporting failure through an empty optional instead of
/// returning an unusable result.
///
/// Delegates to `m.computeInverseWithCheck(inverse, isInvertible)`.
/// NOTE(review): in Eigen this member is only available for fixed-size
/// square matrices up to 4×4 — confirm against the Eigen version in use.
///
/// @tparam MatrixType type of the matrix to invert
/// @tparam ResultType type that receives the inverse (defaults to MatrixType)
/// @param m the matrix to invert
/// @return the inverse if the check reports the matrix invertible,
///         std::nullopt otherwise
template <typename MatrixType, typename ResultType = MatrixType>
std::optional<ResultType> safeInverse(const MatrixType& m) noexcept {
  bool isInvertible = false;
  ResultType inverse;

  m.computeInverseWithCheck(inverse, isInvertible);

  if (!isInvertible) {
    return std::nullopt;
  }
  return inverse;
}
200 
/// Per-type bound on the magnitude of the argument accepted by safeExp().
///
/// Only float and double are specialized. Calling safeExp() with any other
/// type fails to compile, because the primary template has no `value`.
template <typename T>
struct ExpSafeLimit {};
template <>
struct ExpSafeLimit<double> {
  constexpr static double value = 500.0;
};
template <>
struct ExpSafeLimit<float> {
  constexpr static float value = 50.0f;
};

/// Calculate the exponential of @p val, keeping std::exp away from its
/// overflow/underflow range.
///
/// Arguments beyond ±ExpSafeLimit<T>::value are clamped to zero or +infinity.
/// This happens even where std::exp would still return a finite value
/// (e.g. exp(600.0) for double).
///
/// @tparam T floating-point type with an ExpSafeLimit specialization
/// @param val the exponent
/// @return exactly zero if val < -ExpSafeLimit<T>::value, +infinity if
///         val > ExpSafeLimit<T>::value, std::exp(val) otherwise (NaN
///         compares false on both checks and falls through to std::exp)
template <typename T>
constexpr T safeExp(T val) noexcept {
  constexpr T maxExponent = ExpSafeLimit<T>::value;
  constexpr T minExponent = -maxExponent;
  if (val < minExponent) {
    return T{0};
  }

  if (val > maxExponent) {
    return std::numeric_limits<T>::infinity();
  }

  return std::exp(val);
}
235 
236 } // namespace Acts