Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExaTrkXBoostTrackBuildingTests.cpp
Go to the documentation of this file, or view the newest version of ExaTrkXBoostTrackBuildingTests.cpp in the sPHENIX GitHub repository.
1 // This file is part of the Acts project.
2 //
3 // Copyright (C) 2022 CERN for the benefit of the Acts project
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
#include <boost/test/unit_test.hpp>

// NOTE(review): the Acts plugin headers (original lines 11-12) are missing
// from this rendering; they must provide Acts::BoostTrackBuilding and
// Acts::detail::vectorToTensor2D.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

16 BOOST_AUTO_TEST_CASE(test_track_building) {
17  // Make some spacepoint IDs
18  // The spacepoint ids are [100, 101, 102, ...]
19  // They should not be zero based to check if the thing also works if the
20  // spacepoint IDs do not match the node IDs used for the edges
21  std::vector<int> spacepointIds(16);
22  std::iota(spacepointIds.begin(), spacepointIds.end(), 100);
23 
24  // Build 4 tracks with 4 hits
25  std::vector<std::vector<int>> refTracks;
26  for (auto t = 0ul; t < 4; ++t) {
27  refTracks.emplace_back(spacepointIds.begin() + 4 * t,
28  spacepointIds.begin() + 4 * (t + 1));
29  }
30 
31  // Make edges
32  std::vector<int64_t> edges;
33  for (const auto &track : refTracks) {
34  for (auto it = track.begin(); it != track.end() - 1; ++it) {
35  // edges must be 0 based, so subtract 100 again
36  edges.push_back(*it - 100);
37  edges.push_back(*std::next(it) - 100);
38  }
39  }
40 
41  auto edgeTensor =
42  Acts::detail::vectorToTensor2D(edges, 2).t().contiguous().clone();
43  auto dummyWeights = torch::ones(edges.size() / 2, torch::kFloat32);
44 
45  // Run Track building
47  Acts::BoostTrackBuilding trackBuilder(std::move(logger));
48 
49  auto testTracks = trackBuilder({}, edgeTensor, dummyWeights, spacepointIds);
50 
51  // Sort tracks, so we can find them
52  std::for_each(testTracks.begin(), testTracks.end(),
53  [](auto &t) { std::sort(t.begin(), t.end()); });
54  std::for_each(refTracks.begin(), refTracks.end(),
55  [](auto &t) { std::sort(t.begin(), t.end()); });
56 
57  // Check what we have here
58  for (const auto &refTrack : refTracks) {
59  auto found = std::find(testTracks.begin(), testTracks.end(), refTrack);
60  BOOST_CHECK(found != testTracks.end());
61  }
62 }