Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CreateSeedsForGroupSycl.cpp
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file CreateSeedsForGroupSycl.cpp
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2020-2021 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 // System include(s)
10 #include <algorithm>
11 #include <cstdint>
12 #include <cstring>
13 #include <exception>
14 #include <functional>
15 #include <memory>
16 #include <vector>
17 
18 // Acts include(s)
20 
21 // SYCL plugin include(s)
25 
26 #include "../Utilities/Arrays.hpp"
27 #include "DupletSearch.hpp"
28 #include "LinearTransform.hpp"
29 #include "TripletFilter.hpp"
30 #include "TripletSearch.hpp"
31 
32 // VecMem include(s).
33 #include "vecmem/containers/data/jagged_vector_buffer.hpp"
34 #include "vecmem/containers/data/vector_buffer.hpp"
35 #include "vecmem/utils/sycl/copy.hpp"
36 
37 // SYCL include
38 #include <CL/sycl.hpp>
39 
40 namespace Acts::Sycl {
41 // Kernel classes in order of execution.
42 class ind_copy_bottom_kernel;
43 class ind_copy_top_kernel;
44 class triplet_search_kernel;
45 class filter_2sp_fixed_kernel;
46 
48  QueueWrapper wrappedQueue, vecmem::memory_resource& resource,
49  vecmem::memory_resource* device_resource,
50  const detail::DeviceSeedFinderConfig& seedFinderConfig,
51  const DeviceExperimentCuts& deviceCuts,
52  vecmem::vector<detail::DeviceSpacePoint>& bottomSPs,
53  vecmem::vector<detail::DeviceSpacePoint>& middleSPs,
54  vecmem::vector<detail::DeviceSpacePoint>& topSPs,
55  std::vector<std::vector<detail::SeedData>>& seeds) {
56  // Each vector stores data of space points in simplified
57  // structures of float variables
58  // M: number of middle space points
59  // B: number of bottom space points
60  // T: number of top space points
61  const uint32_t M = middleSPs.size();
62  const uint32_t B = bottomSPs.size();
63  const uint32_t T = topSPs.size();
64 
65  // Up to the Nth space point, the sum of compatible bottom/top space points.
66  // We need these for indexing other vectors later in the algorithm.
67  // These are prefix sum arrays, with a leading zero.
68  // Those will be created with either host or shared memory resource
69  vecmem::vector<uint32_t> sumBotMidPrefix(&resource);
70  sumBotMidPrefix.push_back(0);
71  vecmem::vector<uint32_t> sumTopMidPrefix(&resource);
72  sumTopMidPrefix.push_back(0);
73  vecmem::vector<uint32_t> sumBotTopCombPrefix(&resource);
74  sumBotTopCombPrefix.push_back(0);
75 
76  // After completing the duplet search, we'll have successfully constructed
77  // two bipartite graphs for bottom-middle and top-middle space points.
78  // We store the indices of the middle space points of the corresponding
79  // edges.
80  vecmem::vector<uint32_t> indMidBotComp(&resource);
81  vecmem::vector<uint32_t> indMidTopComp(&resource);
82 
83  try {
84  auto* q = wrappedQueue.getQueue();
85  uint64_t globalBufferSize =
86  q->get_device().get_info<cl::sycl::info::device::global_mem_size>();
87  uint64_t maxWorkGroupSize =
88  q->get_device().get_info<cl::sycl::info::device::max_work_group_size>();
89  vecmem::sycl::copy copy(wrappedQueue.getQueue());
90 
91  // Calculate 2 dimensional range of bottom-middle duplet search kernel
92  // We'll have a total of M*B threads globally, but we need to give the
93  // nd_range the global dimensions so that they are an exact multiple of
94  // the local dimensions. That's why we need this calculation.
95 
96  cl::sycl::nd_range<2> bottomDupletNDRange =
97  calculate2DimNDRange(M, B, maxWorkGroupSize);
98  cl::sycl::nd_range<2> topDupletNDRange =
99  calculate2DimNDRange(M, T, maxWorkGroupSize);
100 
101  // Create views of the space point vectors.
102  // They will be constructed differently depending on the number of memory
103  // resources given.
104  std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceSpacePoint>>
105  deviceBottomSPs, deviceTopSPs, deviceMiddleSPs;
106  vecmem::data::vector_view<detail::DeviceSpacePoint> bottomSPsView,
107  topSPsView, middleSPsView;
108  if (!device_resource) {
109  bottomSPsView = vecmem::get_data(bottomSPs);
110  topSPsView = vecmem::get_data(topSPs);
111  middleSPsView = vecmem::get_data(middleSPs);
112  } else {
113  deviceBottomSPs = std::make_unique<
114  vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
115  B, *device_resource);
116  deviceTopSPs = std::make_unique<
117  vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
118  T, *device_resource);
119  deviceMiddleSPs = std::make_unique<
120  vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
121  M, *device_resource);
122 
123  copy(vecmem::get_data(bottomSPs), *deviceBottomSPs);
124  copy(vecmem::get_data(topSPs), *deviceTopSPs);
125  copy(vecmem::get_data(middleSPs), *deviceMiddleSPs);
126 
127  bottomSPsView = vecmem::get_data(*deviceBottomSPs);
128  topSPsView = vecmem::get_data(*deviceTopSPs);
129  middleSPsView = vecmem::get_data(*deviceMiddleSPs);
130  }
131  //*********************************************//
132  // ********** DUPLET SEARCH - BEGIN ********** //
133  //*********************************************//
134 
135  // Create the output data of the duplet search - jagged vectors.
136  std::unique_ptr<vecmem::data::jagged_vector_buffer<uint32_t>>
137  midBotDupletBuffer;
138  std::unique_ptr<vecmem::data::jagged_vector_buffer<uint32_t>>
139  midTopDupletBuffer;
140 
141  midBotDupletBuffer =
142  std::make_unique<vecmem::data::jagged_vector_buffer<uint32_t>>(
143  std::vector<std::size_t>(M, 0), std::vector<std::size_t>(M, B),
144  (device_resource ? *device_resource : resource),
145  (device_resource ? &resource : nullptr));
146  midTopDupletBuffer =
147  std::make_unique<vecmem::data::jagged_vector_buffer<uint32_t>>(
148  std::vector<std::size_t>(M, 0), std::vector<std::size_t>(M, T),
149  (device_resource ? *device_resource : resource),
150  (device_resource ? &resource : nullptr));
151  copy.setup(*midBotDupletBuffer);
152  copy.setup(*midTopDupletBuffer);
153 
154  // Perform the middle-bottom duplet search.
155  auto middleBottomEvent = q->submit([&](cl::sycl::handler& h) {
157  middleSPsView, bottomSPsView, *midBotDupletBuffer, seedFinderConfig);
158  h.parallel_for<class DupletSearchBottomKernel>(bottomDupletNDRange,
159  kernel);
160  });
161 
162  // Perform the middle-top duplet search.
163  auto middleTopEvent = q->submit([&](cl::sycl::handler& h) {
165  middleSPsView, topSPsView, *midTopDupletBuffer, seedFinderConfig);
166  h.parallel_for<class DupletSearchTopKernel>(topDupletNDRange, kernel);
167  });
168  middleBottomEvent.wait_and_throw();
169  middleTopEvent.wait_and_throw();
170  //*********************************************//
171  // *********** DUPLET SEARCH - END *********** //
172  //*********************************************//
173 
174  // Get the sizes of the inner vectors of the jagged vector - number of
175  // compatible bottom/top SPs for each MiddleSP.
176  auto countBotDuplets = copy.get_sizes(*midBotDupletBuffer);
177  auto countTopDuplets = copy.get_sizes(*midTopDupletBuffer);
178  // Construct prefix sum arrays of duplet counts.
179  // These will later be used to index other arrays based on middle SP
180  // indices.
181  for (uint32_t i = 1; i < M + 1; ++i) {
182  sumBotMidPrefix.push_back(sumBotMidPrefix.at(i - 1) +
183  countBotDuplets[i - 1]);
184  sumTopMidPrefix.push_back(sumTopMidPrefix.at(i - 1) +
185  countTopDuplets[i - 1]);
186  sumBotTopCombPrefix.push_back(sumBotTopCombPrefix.at(i - 1) +
187  countBotDuplets[i - 1] *
188  countTopDuplets[i - 1]);
189  }
190  // Number of edges for middle-bottom and middle-top duplet bipartite graphs.
191  const uint64_t edgesBottom = sumBotMidPrefix[M];
192  const uint64_t edgesTop = sumTopMidPrefix[M];
193  // Number of possible compatible triplets. This is the sum of the
194  // combination of the number of compatible bottom and compatible top duplets
195  // per middle space point. (nb0*nt0 + nb1*nt1 + ... where nbk is the number
196  // of comp. bot. SPs for the kth middle SP)
197  const uint64_t edgesComb = sumBotTopCombPrefix[M];
198 
199  indMidBotComp.reserve(edgesBottom);
200  indMidTopComp.reserve(edgesTop);
201 
202  // Fill arrays of middle SP indices of found duplets (bottom and top).
203  for (uint32_t mid = 0; mid < M; ++mid) {
204  std::fill_n(std::back_inserter(indMidBotComp), countBotDuplets[mid], mid);
205  std::fill_n(std::back_inserter(indMidTopComp), countTopDuplets[mid], mid);
206  }
207 
208  if (edgesBottom > 0 && edgesTop > 0) {
209  // Calculate global and local range of execution for edgesBottom number of
210  // threads. Local range is the same as block size in CUDA.
211  cl::sycl::nd_range<1> edgesBotNdRange =
212  calculate1DimNDRange(edgesBottom, maxWorkGroupSize);
213 
214  // Global and local range of execution for edgesTop number of threads.
215  cl::sycl::nd_range<1> edgesTopNdRange =
216  calculate1DimNDRange(edgesTop, maxWorkGroupSize);
217 
218  // EXPLANATION OF INDEXING (first part)
219  /*
220  (for bottom-middle duplets, but it is the same for middle-tops)
221  In case we have 4 middle SP and 5 bottom SP, our temporary array of
222  the compatible bottom duplet indices would look like this:
223  ---------------------
224  mid0 | 0 | 3 | 4 | 1 | - | Indices in the columns correspond to
225  mid1 | 3 | 2 | - | - | - | bottom SP indices in the bottomSPs
226  mid2 | - | - | - | - | - | array. Threads are executed concurrently,
227  mid3 | 4 | 2 | 1 | - | - | so the order of indices is random.
228  ---------------------
229  We will refer to this structure as a bipartite graph, as it can be
230  described by a graph of nodes for middle and bottom SPs, and edges
231  between one middle and one bottom SP, but never two middle or two
232  bottom SPs.
233  We will flatten this matrix out, and store the indices the
234  following way (this is indBotDupletBuffer):
235  -------------------------------------
236  | 0 | 3 | 4 | 1 | 3 | 2 | 4 | 2 | 1 |
237  -------------------------------------
238  Also the length of this array is equal to edgesBottom, which is 9 in
239  this example. It is the number of the edges of the bottom-middle
240  bipartite graph.
241  To find out where the indices of bottom SPs start for a particular
242  middle SP, we use prefix sum arrays.
243  We know how many duplets were found for each middle SP (this is
244  countBotDuplets).
245  -----------------
246  | 4 | 2 | 0 | 3 |
247  -----------------
248  We will make a prefix sum array of these counts, with a leading zero:
249  (this is sumBotMidPrefix)
250  ---------------------
251  | 0 | 4 | 6 | 6 | 9 |
252  ---------------------
253  If we have the middle SP with index 1, then we know that the indices
254  of the compatible bottom SPs are in the range (left closed, right
  open) [sumBotMidPrefix[1], sumBotMidPrefix[2]) of indBotDupletBuffer.
256  In this case, these indices are 3 and 2, so we'd use these to index
257  views of bottomSPs to gather data about the bottom SP.
258  To be able to get the indices of middle SPs in constant time inside
259  kernels, we will also prepare arrays that store the indices of the
260  middleSPs of the edges (indMidBotComp).
261  -------------------------------------
262  | 0 | 0 | 0 | 0 | 1 | 1 | 3 | 3 | 3 |
263  -------------------------------------
264  (For the same purpose, we could also do a binary search on the
265  sumBotMidPrefix array, and we will do exactly that later, in the triplet
266  search kernel.)
267  We will execute the coordinate transformation on edgesBottom threads,
268  or 9 in our example.
269  The size of the array storing our transformed coordinates
270  (linearBotBuffer) is also edgesBottom, the sum of bottom duplets we
271  found so far.
272  */
273 
274  // We store the indices of the BOTTOM/TOP space points of the edges of
275  // the bottom-middle and top-middle bipartite duplet graphs. They index
276  // the bottomSPs and topSPs vectors.
277 
278  // We store the indices of the MIDDLE space points of the edges of the
279  // bottom-middle and top-middle bipartite duplet graphs.
280  // They index the middleSP vector.
281  // indMidBotComp;
282  // indMidTopComp;
283 
284  // Partial sum arrays of deviceNumBot and deviceNum
285  // Partial sum array of the combinations of compatible bottom and top
286  // space points per middle space point.
287  // Allocations for coordinate transformation.
288 
289  // Buffers for Flattening the jagged vectors
290  std::unique_ptr<vecmem::data::vector_buffer<uint32_t>> indBotDupletBuffer;
291  std::unique_ptr<vecmem::data::vector_buffer<uint32_t>> indTopDupletBuffer;
292 
293  indBotDupletBuffer =
294  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
295  edgesBottom, (device_resource ? *device_resource : resource));
296  indTopDupletBuffer =
297  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
298  edgesTop, (device_resource ? *device_resource : resource));
299 
300  copy.setup(*indBotDupletBuffer);
301  copy.setup(*indTopDupletBuffer);
302 
303  // Pointers constructed in case the device memory resource was given.
304  std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
305  device_sumBotMidPrefix, device_sumTopMidPrefix,
306  device_sumBotTopCombPrefix;
307  // Vecmem views of the prefix sums used throughout the later code.
308  vecmem::data::vector_view<uint32_t> sumBotMidView, sumTopMidView,
309  sumBotTopCombView;
310 
311  // Same behaviour for the vectors of indices
312  std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
313  device_indMidBotComp, device_indMidTopComp;
314  vecmem::data::vector_view<uint32_t> indMidBotCompView, indMidTopCompView;
315  // Copy indices from temporary matrices to final, optimal size vectors.
316  // We will use these for easier indexing.
317  if (!device_resource) {
318  sumBotMidView = vecmem::get_data(sumBotMidPrefix);
319  sumTopMidView = vecmem::get_data(sumTopMidPrefix);
320  sumBotTopCombView = vecmem::get_data(sumBotTopCombPrefix);
321 
322  indMidBotCompView = vecmem::get_data(indMidBotComp);
323  indMidTopCompView = vecmem::get_data(indMidTopComp);
324  } else {
325  device_sumBotMidPrefix =
326  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
327  M + 1, *device_resource);
328  device_sumTopMidPrefix =
329  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
330  M + 1, *device_resource);
331  device_sumBotTopCombPrefix =
332  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
333  M + 1, *device_resource);
334 
335  device_indMidBotComp =
336  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
337  edgesBottom, *device_resource);
338  device_indMidTopComp =
339  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
340  edgesTop, *device_resource);
341 
342  copy(vecmem::get_data(sumBotMidPrefix), *device_sumBotMidPrefix);
343  copy(vecmem::get_data(sumTopMidPrefix), *device_sumTopMidPrefix);
344  copy(vecmem::get_data(sumBotTopCombPrefix),
345  *device_sumBotTopCombPrefix);
346 
347  copy(vecmem::get_data(indMidBotComp), *device_indMidBotComp);
348  copy(vecmem::get_data(indMidTopComp), *device_indMidTopComp);
349 
350  sumBotMidView = vecmem::get_data(*device_sumBotMidPrefix);
351  sumTopMidView = vecmem::get_data(*device_sumTopMidPrefix);
352  sumBotTopCombView = vecmem::get_data(*device_sumBotTopCombPrefix);
353 
354  indMidBotCompView = vecmem::get_data(*device_indMidBotComp);
355  indMidTopCompView = vecmem::get_data(*device_indMidTopComp);
356  }
357  auto midBotDupletView = vecmem::get_data(*midBotDupletBuffer);
358  auto indBotDupletView = vecmem::get_data(*indBotDupletBuffer);
359  auto indBotEvent = q->submit([&](cl::sycl::handler& h) {
360  h.parallel_for<ind_copy_bottom_kernel>(
361  edgesBotNdRange, [=](cl::sycl::nd_item<1> item) {
362  auto idx = item.get_global_linear_id();
363  if (idx < edgesBottom) {
364  vecmem::device_vector<uint32_t> deviceIndMidBot(
365  indMidBotCompView),
366  sumBotMidPrefix(sumBotMidView),
367  indBotDuplets(indBotDupletView);
368  vecmem::jagged_device_vector<const uint32_t> midBotDuplets(
369  midBotDupletView);
370  auto mid = deviceIndMidBot[idx];
371  auto ind = midBotDuplets[mid][idx - sumBotMidPrefix[mid]];
372  indBotDuplets[idx] = ind;
373  }
374  });
375  });
376  auto midTopDupletView = vecmem::get_data(*midTopDupletBuffer);
377  auto indTopDupletView = vecmem::get_data(*indTopDupletBuffer);
378  auto indTopEvent = q->submit([&](cl::sycl::handler& h) {
379  h.parallel_for<ind_copy_top_kernel>(
380  edgesTopNdRange, [=](cl::sycl::nd_item<1> item) {
381  auto idx = item.get_global_linear_id();
382  if (idx < edgesTop) {
383  vecmem::device_vector<uint32_t> deviceIndMidTop(
384  indMidTopCompView),
385  sumTopMidPrefix(sumTopMidView),
386  indTopDuplets(indTopDupletView);
387  vecmem::jagged_device_vector<const uint32_t> midTopDuplets(
388  midTopDupletView);
389  auto mid = deviceIndMidTop[idx];
390  auto ind = midTopDuplets[mid][idx - sumTopMidPrefix[mid]];
391  indTopDuplets[idx] = ind;
392  }
393  });
394  });
395  indBotEvent.wait_and_throw();
396  indTopEvent.wait_and_throw();
397 
398  // Create the output data of the linear transform
399  std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>
400  linearBotBuffer;
401  std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>
402  linearTopBuffer;
403 
404  linearBotBuffer = std::make_unique<
405  vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>(
406  edgesBottom, (device_resource ? *device_resource : resource));
407  linearTopBuffer = std::make_unique<
408  vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>(
409  edgesTop, (device_resource ? *device_resource : resource));
410 
411  copy.setup(*linearBotBuffer);
412  copy.setup(*linearTopBuffer);
413 
414  //************************************************//
415  // *** LINEAR EQUATION TRANSFORMATION - BEGIN *** //
416  //************************************************//
417 
418  // transformation of circle equation (x,y) into linear equation (u,v)
419  // x^2 + y^2 - 2x_0*x - 2y_0*y = 0
420  // is transformed into
421  // 1 - 2x_0*u - 2y_0*v = 0
422 
423  // coordinate transformation middle-bottom pairs
424  auto linB = q->submit([&](cl::sycl::handler& h) {
426  middleSPsView, bottomSPsView, indMidBotCompView,
427  *indBotDupletBuffer, edgesBottom, *linearBotBuffer);
428  h.parallel_for<class TransformCoordBottomKernel>(edgesBotNdRange,
429  kernel);
430  });
431 
432  // coordinate transformation middle-top pairs
433  auto linT = q->submit([&](cl::sycl::handler& h) {
435  middleSPsView, topSPsView, indMidTopCompView, *indTopDupletBuffer,
436  edgesTop, *linearTopBuffer);
437  h.parallel_for<class TransformCoordTopKernel>(edgesTopNdRange, kernel);
438  });
439 
440  //************************************************//
441  // **** LINEAR EQUATION TRANSFORMATION - END **** //
442  //************************************************//
443 
444  //************************************************//
445  // *********** TRIPLET SEARCH - BEGIN *********** //
446  //************************************************//
447 
448  // EXPLANATION OF INDEXING (second part)
449  /*
450  For the triplet search, we calculate the upper limit of constructible
451  triplets.
452 
453  For this, we multiply the number of compatible bottom and compatible
454  top SPs for each middle SP, and add these together.
  (nb0*nt0 + nb1*nt1 + ... where nbk is the number of compatible bottom
  SPs for the kth middle SP, and ntk is the same for top SPs)
457 
458  sumBotTopCombPrefix is a prefix sum array (of length M+1) of the
459  calculated combinations.
460 
  sumBotTopCombPrefix (M = number of middle space points):
  | 0 | nb0*nt0 | nb0*nt0 + nb1*nt1 | ... | nb0*nt0 + ... + nb(M-1)*nt(M-1) |
  i.e. sumBotTopCombPrefix[k] = sum_{i=0}^{k-1} nbi*nti, for k = 0, ..., M.
466 
467  We will start kernels and reserve memory for these combinations but
468  only so much we can fit into memory at once.
469 
470  We limit our memory usage to globalBufferSize/2, this is currently
471  hard-coded, but it could be configured. Actually, it would be better
472  to use a separate object that manages memory allocations and
473  deallocations and we could ask it to lend us as much memory as it is
474  happy to give.
475 
476  For later, let maxMemoryAllocation be maximum allocatable memory for
477  triplet search.
478 
  We start by summing up the combinations until we arrive at the first
  k for which

    sum_{i=0}^{k} nbi*nti

  would no longer fit into maxMemoryAllocation (or until k == M).
486 
487  So we know, that we need to start our first kernel for the first k
488  middle SPs.
489 
490  Inside the triplet search kernel we start with a binary search, to
491  find out which middle SP the thread corresponds to. Note, that
492  sumBotTopCombPrefix is a monotone increasing series of values which
493  allows us to do a binary search on it.
494 
495  Inside the triplet search kernel we count the triplets for fixed
496  bottom and middle SP. This is deviceCountTriplets.
497 
498  The triplet filter kernel is calculated on threads equal to all possible
499  bottom-middle combinations for the first k middle SPs, which are
500  the sum of bottom-middle duplets. (For the next kernel it would be the
501  bottom-middle combinations from the (k+1)th middle SP to another jth
502  middle SP j<=M.)
503 
504  This will be numTripletFilterThreads =
505  sumBotMidPrefix[lastMiddle] - sumBotMidPrefix[firstMiddle]
506 
507  If the triplet search and triplet filter kernel finished, we continue
508  summing up possible triplet combinations from the (k+1)th middle SP.
509 
510  Inside the kernels we need to use offset because of this, to be able to
511  map threads to space point indices.
512 
513  This offset is sumCombUptoFirstMiddle.
514  */
515 
516  const auto maxMemoryAllocation =
517  std::min(edgesComb,
518  globalBufferSize / uint64_t((sizeof(detail::DeviceTriplet) +
519  sizeof(detail::SeedData)) *
520  2));
521 
522  std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceTriplet>>
523  curvImpactBuffer;
524  std::unique_ptr<vecmem::data::vector_buffer<detail::SeedData>>
525  seedArrayBuffer;
526 
527  curvImpactBuffer =
528  std::make_unique<vecmem::data::vector_buffer<detail::DeviceTriplet>>(
529  maxMemoryAllocation,
530  (device_resource ? *device_resource : resource));
531  seedArrayBuffer =
532  std::make_unique<vecmem::data::vector_buffer<detail::SeedData>>(
533  maxMemoryAllocation, 0,
534  (device_resource ? *device_resource : resource));
535 
536  copy.setup(*curvImpactBuffer);
537  copy.setup(*seedArrayBuffer);
538  // Reserve memory in advance for seed indices and weight
539  // Other way around would allocate it inside the loop
540  // -> less memory usage, but more frequent allocation and deallocation
541 
542  // Counting the seeds in the second kernel allows us to copy back the
543  // right number of seeds, and no more.
544  seeds.resize(M);
545  vecmem::vector<uint32_t> countTriplets(&resource);
546  countTriplets.resize(edgesBottom, 0);
547 
548  std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
549  deviceCountTriplets;
550  vecmem::data::vector_view<uint32_t> countTripletsView;
551 
552  if (!device_resource) {
553  countTripletsView = vecmem::get_data(countTriplets);
554  } else {
555  deviceCountTriplets =
556  std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
557  edgesBottom, *device_resource);
558  copy(vecmem::get_data(countTriplets), *deviceCountTriplets);
559  countTripletsView = vecmem::get_data(*deviceCountTriplets);
560  }
561 
562  // Do the triplet search and triplet filter for 2 sp fixed for middle
563  // space points in the interval [firstMiddle, lastMiddle).
564 
565  uint32_t lastMiddle = 0;
566  for (uint32_t firstMiddle = 0; firstMiddle < M;
567  firstMiddle = lastMiddle) {
568  // Determine the interval [firstMiddle, lastMiddle) right end based
569  // on memory requirements.
570  while (lastMiddle + 1 <= M && (sumBotTopCombPrefix[lastMiddle + 1] -
571  sumBotTopCombPrefix[firstMiddle] <
572  maxMemoryAllocation)) {
573  ++lastMiddle;
574  }
575 
576  const auto numTripletSearchThreads =
577  sumBotTopCombPrefix[lastMiddle] - sumBotTopCombPrefix[firstMiddle];
578 
579  if (numTripletSearchThreads == 0) {
580  ++lastMiddle;
581  continue;
582  }
583 
584  copy.setup(*seedArrayBuffer);
585  const auto numTripletFilterThreads =
586  sumBotMidPrefix[lastMiddle] - sumBotMidPrefix[firstMiddle];
587 
588  // Nd_range with maximum block size for triplet search and filter.
589  // (global and local range is already given)
590  cl::sycl::nd_range<1> tripletSearchNDRange =
591  calculate1DimNDRange(numTripletSearchThreads, maxWorkGroupSize);
592 
593  cl::sycl::nd_range<1> tripletFilterNDRange =
594  calculate1DimNDRange(numTripletFilterThreads, maxWorkGroupSize);
595 
596  auto tripletKernel = q->submit([&](cl::sycl::handler& h) {
597  h.depends_on({linB, linT});
598  detail::TripletSearch kernel(
599  sumBotTopCombView, numTripletSearchThreads, firstMiddle,
600  lastMiddle, *midTopDupletBuffer, sumBotMidView, sumTopMidView,
601  *linearBotBuffer, *linearTopBuffer, middleSPsView,
602  *indTopDupletBuffer, countTripletsView, seedFinderConfig,
603  *curvImpactBuffer);
604  h.parallel_for<class triplet_search_kernel>(tripletSearchNDRange,
605  kernel);
606  });
607 
608  q->submit([&](cl::sycl::handler& h) {
609  h.depends_on(tripletKernel);
610  detail::TripletFilter kernel(
611  numTripletFilterThreads, sumBotMidView, firstMiddle,
612  indMidBotCompView, *indBotDupletBuffer, sumBotTopCombView,
613  *midTopDupletBuffer, *curvImpactBuffer, topSPsView,
614  middleSPsView, bottomSPsView, countTripletsView,
615  *seedArrayBuffer, seedFinderConfig, deviceCuts);
616  h.parallel_for<class filter_2sp_fixed_kernel>(tripletFilterNDRange,
617  kernel);
618  }).wait_and_throw();
619  // sync
620  // Retrieve results from triplet search
621  std::vector<detail::SeedData> seedArray;
622  copy(*seedArrayBuffer, seedArray);
623 
624  for (auto& t : seedArray) {
625  seeds[t.middle].push_back(t);
626  }
627  }
628 
629  //************************************************//
630  // ************ TRIPLET SEARCH - END ************ //
631  //************************************************//
632  }
633 
634  } catch (cl::sycl::exception const& e) {
637  ACTS_FATAL("Caught synchronous SYCL exception:\n" << e.what())
638  throw;
639  }
640 };
641 } // namespace Acts::Sycl