26 #include "../Utilities/Arrays.hpp"
33 #include "vecmem/containers/data/jagged_vector_buffer.hpp"
34 #include "vecmem/containers/data/vector_buffer.hpp"
35 #include "vecmem/utils/sycl/copy.hpp"
38 #include <CL/sycl.hpp>
40 namespace Acts::Sycl {
// Kernel-name tag types. They are only declared here, never defined; each is
// passed as the template argument of `handler::parallel_for<...>` so that the
// corresponding lambda kernel has a unique, nameable type.
//   ind_copy_bottom_kernel  - for each flat index idx < edgesBottom, looks up
//                             its middle SP and copies the matching entry of
//                             the jagged mid-bottom duplet buffer into the
//                             flat indBotDuplets buffer.
//   ind_copy_top_kernel     - same as above for the mid-top duplets
//                             (idx < edgesTop, into indTopDuplets).
//   triplet_search_kernel   - kernel name used with tripletSearchNDRange.
//   filter_2sp_fixed_kernel - kernel name used with tripletFilterNDRange;
//                             submitted after (depends on) the triplet search.
42 class ind_copy_bottom_kernel;
43 class ind_copy_top_kernel;
44 class triplet_search_kernel;
45 class filter_2sp_fixed_kernel;
48 QueueWrapper wrappedQueue, vecmem::memory_resource& resource,
49 vecmem::memory_resource* device_resource,
52 vecmem::vector<detail::DeviceSpacePoint>& bottomSPs,
53 vecmem::vector<detail::DeviceSpacePoint>& middleSPs,
54 vecmem::vector<detail::DeviceSpacePoint>& topSPs,
55 std::vector<std::vector<detail::SeedData>>& seeds) {
61 const uint32_t M = middleSPs.size();
62 const uint32_t B = bottomSPs.size();
63 const uint32_t
T = topSPs.size();
69 vecmem::vector<uint32_t> sumBotMidPrefix(&resource);
70 sumBotMidPrefix.push_back(0);
71 vecmem::vector<uint32_t> sumTopMidPrefix(&resource);
72 sumTopMidPrefix.push_back(0);
73 vecmem::vector<uint32_t> sumBotTopCombPrefix(&resource);
74 sumBotTopCombPrefix.push_back(0);
80 vecmem::vector<uint32_t> indMidBotComp(&resource);
81 vecmem::vector<uint32_t> indMidTopComp(&resource);
85 uint64_t globalBufferSize =
86 q->get_device().get_info<cl::sycl::info::device::global_mem_size>();
87 uint64_t maxWorkGroupSize =
88 q->get_device().get_info<cl::sycl::info::device::max_work_group_size>();
89 vecmem::sycl::copy copy(wrappedQueue.
getQueue());
96 cl::sycl::nd_range<2> bottomDupletNDRange =
98 cl::sycl::nd_range<2> topDupletNDRange =
104 std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceSpacePoint>>
105 deviceBottomSPs, deviceTopSPs, deviceMiddleSPs;
106 vecmem::data::vector_view<detail::DeviceSpacePoint> bottomSPsView,
107 topSPsView, middleSPsView;
108 if (!device_resource) {
109 bottomSPsView = vecmem::get_data(bottomSPs);
110 topSPsView = vecmem::get_data(topSPs);
111 middleSPsView = vecmem::get_data(middleSPs);
113 deviceBottomSPs = std::make_unique<
114 vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
115 B, *device_resource);
116 deviceTopSPs = std::make_unique<
117 vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
118 T, *device_resource);
119 deviceMiddleSPs = std::make_unique<
120 vecmem::data::vector_buffer<detail::DeviceSpacePoint>>(
121 M, *device_resource);
123 copy(vecmem::get_data(bottomSPs), *deviceBottomSPs);
124 copy(vecmem::get_data(topSPs), *deviceTopSPs);
125 copy(vecmem::get_data(middleSPs), *deviceMiddleSPs);
127 bottomSPsView = vecmem::get_data(*deviceBottomSPs);
128 topSPsView = vecmem::get_data(*deviceTopSPs);
129 middleSPsView = vecmem::get_data(*deviceMiddleSPs);
136 std::unique_ptr<vecmem::data::jagged_vector_buffer<uint32_t>>
138 std::unique_ptr<vecmem::data::jagged_vector_buffer<uint32_t>>
142 std::make_unique<vecmem::data::jagged_vector_buffer<uint32_t>>(
143 std::vector<std::size_t>(M, 0), std::vector<std::size_t>(M, B),
144 (device_resource ? *device_resource : resource),
145 (device_resource ? &resource :
nullptr));
147 std::make_unique<vecmem::data::jagged_vector_buffer<uint32_t>>(
148 std::vector<std::size_t>(M, 0), std::vector<std::size_t>(M, T),
149 (device_resource ? *device_resource : resource),
150 (device_resource ? &resource :
nullptr));
151 copy.setup(*midBotDupletBuffer);
152 copy.setup(*midTopDupletBuffer);
157 middleSPsView, bottomSPsView, *midBotDupletBuffer, seedFinderConfig);
158 h.parallel_for<
class DupletSearchBottomKernel>(bottomDupletNDRange,
165 middleSPsView, topSPsView, *midTopDupletBuffer, seedFinderConfig);
166 h.parallel_for<
class DupletSearchTopKernel>(topDupletNDRange, kernel);
168 middleBottomEvent.wait_and_throw();
169 middleTopEvent.wait_and_throw();
176 auto countBotDuplets = copy.get_sizes(*midBotDupletBuffer);
177 auto countTopDuplets = copy.get_sizes(*midTopDupletBuffer);
181 for (uint32_t
i = 1;
i < M + 1; ++
i) {
182 sumBotMidPrefix.push_back(sumBotMidPrefix.at(
i - 1) +
183 countBotDuplets[
i - 1]);
184 sumTopMidPrefix.push_back(sumTopMidPrefix.at(
i - 1) +
185 countTopDuplets[
i - 1]);
186 sumBotTopCombPrefix.push_back(sumBotTopCombPrefix.at(
i - 1) +
187 countBotDuplets[
i - 1] *
188 countTopDuplets[
i - 1]);
191 const uint64_t edgesBottom = sumBotMidPrefix[M];
192 const uint64_t edgesTop = sumTopMidPrefix[M];
197 const uint64_t edgesComb = sumBotTopCombPrefix[M];
199 indMidBotComp.reserve(edgesBottom);
200 indMidTopComp.reserve(edgesTop);
203 for (uint32_t mid = 0; mid < M; ++mid) {
204 std::fill_n(std::back_inserter(indMidBotComp), countBotDuplets[mid], mid);
205 std::fill_n(std::back_inserter(indMidTopComp), countTopDuplets[mid], mid);
208 if (edgesBottom > 0 && edgesTop > 0) {
211 cl::sycl::nd_range<1> edgesBotNdRange =
215 cl::sycl::nd_range<1> edgesTopNdRange =
290 std::unique_ptr<vecmem::data::vector_buffer<uint32_t>> indBotDupletBuffer;
291 std::unique_ptr<vecmem::data::vector_buffer<uint32_t>> indTopDupletBuffer;
294 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
295 edgesBottom, (device_resource ? *device_resource : resource));
297 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
298 edgesTop, (device_resource ? *device_resource : resource));
300 copy.setup(*indBotDupletBuffer);
301 copy.setup(*indTopDupletBuffer);
304 std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
305 device_sumBotMidPrefix, device_sumTopMidPrefix,
306 device_sumBotTopCombPrefix;
308 vecmem::data::vector_view<uint32_t> sumBotMidView, sumTopMidView,
312 std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
313 device_indMidBotComp, device_indMidTopComp;
314 vecmem::data::vector_view<uint32_t> indMidBotCompView, indMidTopCompView;
317 if (!device_resource) {
318 sumBotMidView = vecmem::get_data(sumBotMidPrefix);
319 sumTopMidView = vecmem::get_data(sumTopMidPrefix);
320 sumBotTopCombView = vecmem::get_data(sumBotTopCombPrefix);
322 indMidBotCompView = vecmem::get_data(indMidBotComp);
323 indMidTopCompView = vecmem::get_data(indMidTopComp);
325 device_sumBotMidPrefix =
326 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
327 M + 1, *device_resource);
328 device_sumTopMidPrefix =
329 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
330 M + 1, *device_resource);
331 device_sumBotTopCombPrefix =
332 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
333 M + 1, *device_resource);
335 device_indMidBotComp =
336 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
337 edgesBottom, *device_resource);
338 device_indMidTopComp =
339 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
340 edgesTop, *device_resource);
342 copy(vecmem::get_data(sumBotMidPrefix), *device_sumBotMidPrefix);
343 copy(vecmem::get_data(sumTopMidPrefix), *device_sumTopMidPrefix);
344 copy(vecmem::get_data(sumBotTopCombPrefix),
345 *device_sumBotTopCombPrefix);
347 copy(vecmem::get_data(indMidBotComp), *device_indMidBotComp);
348 copy(vecmem::get_data(indMidTopComp), *device_indMidTopComp);
350 sumBotMidView = vecmem::get_data(*device_sumBotMidPrefix);
351 sumTopMidView = vecmem::get_data(*device_sumTopMidPrefix);
352 sumBotTopCombView = vecmem::get_data(*device_sumBotTopCombPrefix);
354 indMidBotCompView = vecmem::get_data(*device_indMidBotComp);
355 indMidTopCompView = vecmem::get_data(*device_indMidTopComp);
357 auto midBotDupletView = vecmem::get_data(*midBotDupletBuffer);
358 auto indBotDupletView = vecmem::get_data(*indBotDupletBuffer);
360 h.parallel_for<ind_copy_bottom_kernel>(
361 edgesBotNdRange, [=](cl::sycl::nd_item<1> item) {
362 auto idx = item.get_global_linear_id();
363 if (
idx < edgesBottom) {
364 vecmem::device_vector<uint32_t> deviceIndMidBot(
366 sumBotMidPrefix(sumBotMidView),
367 indBotDuplets(indBotDupletView);
368 vecmem::jagged_device_vector<const uint32_t> midBotDuplets(
370 auto mid = deviceIndMidBot[
idx];
371 auto ind = midBotDuplets[mid][
idx - sumBotMidPrefix[mid]];
372 indBotDuplets[
idx] = ind;
376 auto midTopDupletView = vecmem::get_data(*midTopDupletBuffer);
377 auto indTopDupletView = vecmem::get_data(*indTopDupletBuffer);
379 h.parallel_for<ind_copy_top_kernel>(
380 edgesTopNdRange, [=](cl::sycl::nd_item<1> item) {
381 auto idx = item.get_global_linear_id();
382 if (
idx < edgesTop) {
383 vecmem::device_vector<uint32_t> deviceIndMidTop(
385 sumTopMidPrefix(sumTopMidView),
386 indTopDuplets(indTopDupletView);
387 vecmem::jagged_device_vector<const uint32_t> midTopDuplets(
389 auto mid = deviceIndMidTop[
idx];
390 auto ind = midTopDuplets[mid][
idx - sumTopMidPrefix[mid]];
391 indTopDuplets[
idx] = ind;
395 indBotEvent.wait_and_throw();
396 indTopEvent.wait_and_throw();
399 std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>
401 std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>
404 linearBotBuffer = std::make_unique<
405 vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>(
406 edgesBottom, (device_resource ? *device_resource : resource));
407 linearTopBuffer = std::make_unique<
408 vecmem::data::vector_buffer<detail::DeviceLinEqCircle>>(
409 edgesTop, (device_resource ? *device_resource : resource));
411 copy.setup(*linearBotBuffer);
412 copy.setup(*linearTopBuffer);
426 middleSPsView, bottomSPsView, indMidBotCompView,
427 *indBotDupletBuffer, edgesBottom, *linearBotBuffer);
428 h.parallel_for<
class TransformCoordBottomKernel>(edgesBotNdRange,
435 middleSPsView, topSPsView, indMidTopCompView, *indTopDupletBuffer,
436 edgesTop, *linearTopBuffer);
437 h.parallel_for<
class TransformCoordTopKernel>(edgesTopNdRange, kernel);
516 const auto maxMemoryAllocation =
522 std::unique_ptr<vecmem::data::vector_buffer<detail::DeviceTriplet>>
524 std::unique_ptr<vecmem::data::vector_buffer<detail::SeedData>>
528 std::make_unique<vecmem::data::vector_buffer<detail::DeviceTriplet>>(
530 (device_resource ? *device_resource : resource));
532 std::make_unique<vecmem::data::vector_buffer<detail::SeedData>>(
533 maxMemoryAllocation, 0,
534 (device_resource ? *device_resource : resource));
536 copy.setup(*curvImpactBuffer);
537 copy.setup(*seedArrayBuffer);
545 vecmem::vector<uint32_t> countTriplets(&resource);
546 countTriplets.resize(edgesBottom, 0);
548 std::unique_ptr<vecmem::data::vector_buffer<uint32_t>>
550 vecmem::data::vector_view<uint32_t> countTripletsView;
552 if (!device_resource) {
553 countTripletsView = vecmem::get_data(countTriplets);
555 deviceCountTriplets =
556 std::make_unique<vecmem::data::vector_buffer<uint32_t>>(
557 edgesBottom, *device_resource);
558 copy(vecmem::get_data(countTriplets), *deviceCountTriplets);
559 countTripletsView = vecmem::get_data(*deviceCountTriplets);
565 uint32_t lastMiddle = 0;
566 for (uint32_t firstMiddle = 0; firstMiddle < M;
567 firstMiddle = lastMiddle) {
570 while (lastMiddle + 1 <= M && (sumBotTopCombPrefix[lastMiddle + 1] -
571 sumBotTopCombPrefix[firstMiddle] <
572 maxMemoryAllocation)) {
576 const auto numTripletSearchThreads =
577 sumBotTopCombPrefix[lastMiddle] - sumBotTopCombPrefix[firstMiddle];
579 if (numTripletSearchThreads == 0) {
584 copy.setup(*seedArrayBuffer);
585 const auto numTripletFilterThreads =
586 sumBotMidPrefix[lastMiddle] - sumBotMidPrefix[firstMiddle];
590 cl::sycl::nd_range<1> tripletSearchNDRange =
593 cl::sycl::nd_range<1> tripletFilterNDRange =
597 h.depends_on({linB, linT});
599 sumBotTopCombView, numTripletSearchThreads, firstMiddle,
600 lastMiddle, *midTopDupletBuffer, sumBotMidView, sumTopMidView,
601 *linearBotBuffer, *linearTopBuffer, middleSPsView,
602 *indTopDupletBuffer, countTripletsView, seedFinderConfig,
604 h.parallel_for<
class triplet_search_kernel>(tripletSearchNDRange,
609 h.depends_on(tripletKernel);
611 numTripletFilterThreads, sumBotMidView, firstMiddle,
612 indMidBotCompView, *indBotDupletBuffer, sumBotTopCombView,
613 *midTopDupletBuffer, *curvImpactBuffer, topSPsView,
614 middleSPsView, bottomSPsView, countTripletsView,
615 *seedArrayBuffer, seedFinderConfig, deviceCuts);
616 h.parallel_for<
class filter_2sp_fixed_kernel>(tripletFilterNDRange,
621 std::vector<detail::SeedData> seedArray;
622 copy(*seedArrayBuffer, seedArray);
624 for (
auto&
t : seedArray) {
625 seeds[
t.middle].push_back(
t);
634 }
catch (cl::sycl::exception
const&
e) {
637 ACTS_FATAL(
"Caught synchronous SYCL exception:\n" << e.what())