Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
AutodiffExtensionWrapper.hpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file AutodiffExtensionWrapper.hpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2020 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
15 
16 #include <autodiff/forward/dual.hpp>
17 #include <autodiff/forward/dual/eigen.hpp>
18 
19 namespace Acts {
20 
22 template <template <typename> typename basic_extension_t>
25  AutodiffExtensionWrapper() = default;
26 
27  // Some typedefs
28  using AutodiffScalar = autodiff::dual;
29  using AutodiffVector3 = Eigen::Matrix<AutodiffScalar, 3, 1>;
30  using AutodiffFreeVector = Eigen::Matrix<AutodiffScalar, eFreeSize, 1>;
31  using AutodiffFreeMatrix =
32  Eigen::Matrix<AutodiffScalar, eFreeSize, eFreeSize>;
33 
34  // The double-extension is needed to communicate with the "outer world" (the
35  // stepper) and ensures it behaves exactly as the underlying extension, with
36  // the exception of the computation of the transport-matrix. The corresponding
37  // autodiff-extension can be found in the RKN4step-member-function (since it
38  // is only needed locally). Another advantage of this approach is, that we do
39  // not differentiate through the adaptive stepsize estimation in the stepper.
40  basic_extension_t<double> m_doubleExtension;
41 
42  // Just call underlying extension
43  template <typename propagator_state_t, typename stepper_t,
44  typename navigator_t>
45  int bid(const propagator_state_t& ps, const stepper_t& st,
46  const navigator_t& na) const {
47  return m_doubleExtension.bid(ps, st, na);
48  }
49 
50  // Just call underlying extension
51  template <typename propagator_state_t, typename stepper_t,
52  typename navigator_t>
53  bool k(const propagator_state_t& state, const stepper_t& stepper,
54  const navigator_t& navigator, Vector3& knew, const Vector3& bField,
55  std::array<double, 4>& kQoP, const int i = 0, const double h = 0.,
56  const Vector3& kprev = Vector3::Zero()) {
57  return m_doubleExtension.k(state, stepper, navigator, knew, bField, kQoP, i,
58  h, kprev);
59  }
60 
61  // Just call underlying extension
62  template <typename propagator_state_t, typename stepper_t,
63  typename navigator_t>
64  bool finalize(propagator_state_t& state, const stepper_t& stepper,
65  const navigator_t& navigator, const double h) const {
66  return m_doubleExtension.finalize(state, stepper, navigator, h);
67  }
68 
69  // Here we call a custom implementation to compute the transport matrix
70  template <typename propagator_state_t, typename stepper_t,
71  typename navigator_t>
72  bool finalize(propagator_state_t& state, const stepper_t& stepper,
73  const navigator_t& navigator, const double h,
74  FreeMatrix& D) const {
75 #if defined(__GNUC__) && __GNUC__ == 12 && !defined(__clang__)
76 #pragma GCC diagnostic push
77 #pragma GCC diagnostic ignored "-Wuse-after-free"
78 #endif
79  m_doubleExtension.finalize(state, stepper, navigator, h);
80  return transportMatrix(state, stepper, navigator, h, D);
81 #if defined(__GNUC__) && __GNUC__ == 12 && !defined(__clang__)
82 #pragma GCC diagnostic pop
83 #endif
84  }
85 
86  private:
87  // A fake stepper-state
89  // dummy defaults which will/should be overwritten
93  bool covTransport = false;
94  };
95 
  // A fake propagator state
  // Minimal stand-in for the propagator state handed to the autodiff
  // extension. It holds references to an options object and a navigation
  // object (in transportMatrix, those of the real propagator state), so an
  // instance must not outlive them.
  // NOTE(review): transportMatrix brace-initializes this struct with a
  // FakeStepperState as its first element and accesses it as `.stepping`,
  // but that member's declaration is not visible in this listing — confirm.
  template <class options_t, class navigation_t>
  struct FakePropState {
    const options_t& options;
    const navigation_t& navigation;
  };
103 
104  // A fake stepper
105  struct FakeStepper {
106  auto position(const FakeStepperState& s) const {
107  return s.pars.template segment<3>(eFreePos0);
108  }
109  auto direction(const FakeStepperState& s) const {
110  return s.pars.template segment<3>(eFreeDir0);
111  }
112  auto qOverP(const FakeStepperState& s) const { return s.pars(eFreeQOverP); }
113  auto absoluteMomentum(const FakeStepperState& s) const {
114  return particleHypothesis(s).extractMomentum(qOverP(s));
115  }
116  auto charge(const FakeStepperState& s) const {
117  return particleHypothesis(s).extractCharge(qOverP(s));
118  }
119  auto particleHypothesis(const FakeStepperState& s) const {
120  return s.particleHypothesis;
121  }
122  };
123 
124  // Here the autodiff jacobian is computed
125  template <typename propagator_state_t, typename stepper_t,
126  typename navigator_t>
127  bool transportMatrix(propagator_state_t& state, const stepper_t& stepper,
128  const navigator_t& navigator, const double h,
129  FreeMatrix& D) const {
130  // Initialize fake stepper
131  using ThisFakePropState =
132  FakePropState<decltype(state.options), decltype(state.navigation)>;
133 
134  ThisFakePropState fstate{FakeStepperState(), state.options,
135  state.navigation};
136 
137  fstate.stepping.particleHypothesis =
138  stepper.particleHypothesis(state.stepping);
139 
140  // Init dependent values for autodiff
141  AutodiffFreeVector initial_params;
142  initial_params.segment<3>(eFreePos0) = stepper.position(state.stepping);
143  initial_params(eFreeTime) = stepper.time(state.stepping);
144  initial_params.segment<3>(eFreeDir0) = stepper.direction(state.stepping);
145  initial_params(eFreeQOverP) = stepper.qOverP(state.stepping);
146 
147  const auto& sd = state.stepping.stepData;
148 
149  // Compute jacobian
150  D = jacobian(
151  [&](const auto& in) {
152  return RKN4step(in, sd, fstate, navigator, h);
153  },
154  wrt(initial_params), at(initial_params))
155  .template cast<double>();
156 
157  return true;
158  }
159 
  /// One RKN4 step evaluated on autodiff scalars; this is the function whose
  /// jacobian transportMatrix computes.
  ///
  /// @param in        free start parameters (the differentiation variables)
  /// @param sd        step data of the real step; only B_first, B_middle and
  ///                  B_last are read here
  /// @param state     fake propagator state, taken by value so its stepping
  ///                  parameters can be overwritten locally
  /// @param navigator navigator, forwarded to the extension
  /// @param h         step size
  /// @return the free parameters after the step
  template <typename step_data_t, typename fake_state_t, typename navigator_t>
  auto RKN4step(const AutodiffFreeVector& in, const step_data_t& sd,
                fake_state_t state, const navigator_t& navigator,
                const double h) const {
    // Initialize fake stepper
    // NOTE(review): the declaration of `stepper` used by the ext calls below
    // is not visible in this listing; presumably a FakeStepper — confirm.

    // Set dependent variables
    state.stepping.pars = in;

    std::array<AutodiffScalar, 4> kQoP;
    std::array<AutodiffVector3, 4> k;

    // Autodiff instance of the extension
    // (a fresh, default-constructed extension templated on the autodiff
    // scalar, created on every call)
    basic_extension_t<AutodiffScalar> ext;

    // Compute k values. Assume all return true, since these parameters
    // are already validated by the "outer RKN4"
    // Stage 0 uses B_first; stages 1 and 2 use B_middle with h/2; stage 3
    // uses B_last with h. Each stage after the first receives the previous k.
    ext.k(state, stepper, navigator, k[0], sd.B_first, kQoP);
    ext.k(state, stepper, navigator, k[1], sd.B_middle, kQoP, 1, h * 0.5, k[0]);
    ext.k(state, stepper, navigator, k[2], sd.B_middle, kQoP, 2, h * 0.5, k[1]);
    ext.k(state, stepper, navigator, k[3], sd.B_last, kQoP, 3, h, k[2]);

    // finalize
    // q/p and time are read back from state.stepping.pars below, so any
    // update the extension makes to them here ends up in the output.
    ext.finalize(state, stepper, navigator, h);

    // Compute RKN4 integration
    // NOTE(review): the declaration of `out` is not visible in this listing;
    // presumably an AutodiffFreeVector — confirm.

    // position
    // x' = x + h*d + h^2/6 * (k0 + k1 + k2)
    out.segment<3>(eFreePos0) = in.segment<3>(eFreePos0) +
                                h * in.segment<3>(eFreeDir0) +
                                h * h / 6. * (k[0] + k[1] + k[2]);

    // direction
    // d' = d + h/6 * (k0 + 2*(k1 + k2) + k3), renormalized to unit length
    auto final_dir =
        in.segment<3>(eFreeDir0) + h / 6. * (k[0] + 2. * (k[1] + k[2]) + k[3]);

    out.segment<3>(eFreeDir0) = final_dir / final_dir.norm();

    // qop
    out(eFreeQOverP) = state.stepping.pars(eFreeQOverP);

    // time
    out(eFreeTime) = state.stepping.pars(eFreeTime);

    return out;
  }
208 };
209 
210 } // namespace Acts