Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LinearTransform.hpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file LinearTransform.hpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2020-2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #pragma once
10 
11 // SYCL plugin include(s).
13 
14 #include "../Utilities/Arrays.hpp"
15 #include "SpacePointType.hpp"
16 
17 // SYCL include(s).
18 #include <CL/sycl.hpp>
19 
20 // System include(s).
21 #include <cassert>
22 #include <cstdint>
23 
24 // VecMem include(s).
25 #include "vecmem/containers/data/vector_view.hpp"
26 #include "vecmem/containers/device_vector.hpp"
27 
28 namespace Acts::Sycl::detail {
29 
31 template <SpacePointType OtherSPType>
33  // Sanity check(s).
34  static_assert((OtherSPType == SpacePointType::Bottom) ||
35  (OtherSPType == SpacePointType::Top),
36  "Class must be instantiated with either "
37  "Acts::Sycl::detail::SpacePointType::Bottom or "
38  "Acts::Sycl::detail::SpacePointType::Top");
39 
40  public:
43  vecmem::data::vector_view<const DeviceSpacePoint> middleSPs,
44  vecmem::data::vector_view<const DeviceSpacePoint> otherSPs,
45  vecmem::data::vector_view<uint32_t> middleIndexLUT,
46  vecmem::data::vector_view<uint32_t> otherIndexLUT, uint32_t nEdges,
47  vecmem::data::vector_view<detail::DeviceLinEqCircle> resultArray)
48  : m_middleSPs(middleSPs),
49  m_otherSPs(otherSPs),
50  m_middleIndexLUT(middleIndexLUT),
51  m_otherIndexLUT(otherIndexLUT),
52  m_nEdges(nEdges),
53  m_resultArray(resultArray) {}
54 
56  void operator()(cl::sycl::nd_item<1> item) const {
57  // Get the index to operate on.
58  const auto idx = item.get_global_linear_id();
59  if (idx >= m_nEdges) {
60  return;
61  }
62 
63  // Translate this one index into indices in the spacepoint arrays.
64  // Note that using asserts with the CUDA backend of dpc++ is not working
65  // quite correctly at the moment. :-( So these checks may need to be
66  // disabled if you need to build for an NVidia backend in Debug mode.
67  vecmem::device_vector<uint32_t> middleIndexLUT(m_middleIndexLUT);
68  const uint32_t middleIndex = middleIndexLUT[idx];
69  assert(middleIndex < m_middleSPs.size());
70  (void)m_middleSPs.size();
71  vecmem::device_vector<uint32_t> otherIndexLUT(m_otherIndexLUT);
72  const uint32_t otherIndex = otherIndexLUT[idx];
73  assert(otherIndex < m_otherSPs.size());
74  (void)m_otherSPs.size();
75 
76  // Create a copy of the spacepoint objects for the current thread. On
77  // dedicated GPUs this provides a better performance than accessing
78  // variables one-by-one from global device memory.
79  const vecmem::device_vector<const DeviceSpacePoint> middleSPs(m_middleSPs);
80  const DeviceSpacePoint middleSP = middleSPs[middleIndex];
81  const vecmem::device_vector<const DeviceSpacePoint> otherSPs(m_otherSPs);
82  const DeviceSpacePoint otherSP = otherSPs[otherIndex];
83 
84  // Calculate some "helper variables" for the coordinate linear
85  // transformation.
86  const float cosPhiM = middleSP.x / middleSP.r;
87  const float sinPhiM = middleSP.y / middleSP.r;
88 
89  const float deltaX = otherSP.x - middleSP.x;
90  const float deltaY = otherSP.y - middleSP.y;
91  const float deltaZ = otherSP.z - middleSP.z;
92 
93  const float x = deltaX * cosPhiM + deltaY * sinPhiM;
94  const float y = deltaY * cosPhiM - deltaX * sinPhiM;
95  const float iDeltaR2 = 1.f / (deltaX * deltaX + deltaY * deltaY);
96 
97  // Create the result object.
98  DeviceLinEqCircle result;
99  result.iDeltaR = cl::sycl::sqrt(iDeltaR2);
100  result.cotTheta = deltaZ * result.iDeltaR;
101  if constexpr (OtherSPType == SpacePointType::Bottom) {
102  result.cotTheta = -(result.cotTheta);
103  }
104  result.zo = middleSP.z - middleSP.r * result.cotTheta;
105  result.u = x * iDeltaR2;
106  result.v = y * iDeltaR2;
107  result.er =
108  ((middleSP.varZ + otherSP.varZ) +
109  (result.cotTheta * result.cotTheta) * (middleSP.varR + otherSP.varR)) *
110  iDeltaR2;
111 
112  // Store the result in the result vector
113  vecmem::device_vector<detail::DeviceLinEqCircle> resultArray(m_resultArray);
114  resultArray[idx] = result;
115  return;
116  }
117 
118  private:
120  vecmem::data::vector_view<const DeviceSpacePoint> m_middleSPs;
122  vecmem::data::vector_view<const DeviceSpacePoint> m_otherSPs;
123 
125  vecmem::data::vector_view<uint32_t> m_middleIndexLUT;
127  vecmem::data::vector_view<uint32_t> m_otherIndexLUT;
128 
130  uint32_t m_nEdges;
131 
133  vecmem::data::vector_view<detail::DeviceLinEqCircle> m_resultArray;
134 
135 }; // class LinearTransform
136 
137 } // namespace Acts::Sycl::detail