15 #include <boost/beast/core/span.hpp>
16 #include <boost/graph/adjacency_list.hpp>
17 #include <boost/graph/connected_components.hpp>
18 #include <torch/torch.h>
// Labels every graph vertex with the index of the connected component it
// belongs to, and returns the number of components found.
//
// Edges are given as three parallel spans: edge i joins vertex
// rowIndices[i] and vertex colIndices[i] and carries weight edgeWeights[i].
// The weight is stored as the edge property but plays no role in the
// component search. rowIndices, colIndices and edgeWeights are only read
// here, even though they are taken by non-const reference.
// NOTE(review): the triples are consumed in lockstep via Acts::zip;
// presumably all three spans must have equal length -- confirm how
// Acts::zip treats a length mismatch.
//
// trackLabels is an output parameter. boost::connected_components writes one
// component index per graph vertex through a raw pointer to its first
// element, so it must already be sized to at least the graph's vertex count.
21 template <
typename vertex_t,
typename weight_t>
23 boost::beast::span<vertex_t>& rowIndices,
24 boost::beast::span<vertex_t>& colIndices,
25 boost::beast::span<weight_t>& edgeWeights,
26 std::vector<vertex_t>& trackLabels) {
// NOTE(review): boost::connected_components is documented for undirected
// graphs -- confirm the directedness selector of this adjacency_list is
// boost::undirectedS.
28 boost::adjacency_list<boost::vecS,
// Insert one edge per (row, col, weight) triple.
// NOTE(review): if the vertex container is boost::vecS, add_edge silently
// grows the graph when an endpoint index is >= num_vertices(g), and the
// label write below would then run past the end of trackLabels. Endpoint
// indices are presumably guaranteed in range by the caller -- confirm.
37 for (
const auto [row,
col, weight] :
38 Acts::zip(rowIndices, colIndices, edgeWeights)) {
39 boost::add_edge(row,
col, weight,
g);
// boost::connected_components returns the number of components.
// NOTE(review): `&trackLabels[0]` is undefined behaviour when trackLabels is
// empty. Callers must rule that case out before calling, not after.
// `trackLabels.data()` would avoid the operator[] precondition.
42 return boost::connected_components(
g, &trackLabels[0]);
49 std::any nodes, std::any
edges, std::any weights,
50 std::vector<int>& spacepointIDs,
int) {
52 const auto edgeTensor = std::any_cast<torch::Tensor>(
edges).to(torch::kCPU);
53 const auto edgeWeightTensor =
54 std::any_cast<torch::Tensor>(weights).to(torch::kCPU);
56 assert(edgeTensor.size(0) == 2);
57 assert(edgeTensor.size(1) == edgeWeightTensor.size(0));
59 const auto numSpacepoints = spacepointIDs.size();
60 const auto numEdges =
static_cast<std::size_t
>(edgeWeightTensor.size(0));
63 ACTS_WARNING(
"No edges remained after edge classification");
67 using vertex_t = int64_t;
68 using weight_t = float;
70 boost::beast::span<vertex_t> rowIndices(edgeTensor.data_ptr<vertex_t>(),
72 boost::beast::span<vertex_t> colIndices(
73 edgeTensor.data_ptr<vertex_t>() + numEdges, numEdges);
74 boost::beast::span<weight_t> edgeWeights(edgeWeightTensor.data_ptr<
float>(),
77 std::vector<vertex_t> trackLabels(numSpacepoints);
79 auto numberLabels = weaklyConnectedComponents<vertex_t, weight_t>(
80 numSpacepoints, rowIndices, colIndices, edgeWeights, trackLabels);
82 ACTS_VERBOSE(
"Number of track labels: " << trackLabels.size());
83 ACTS_VERBOSE(
"Number of unique track labels: " << [&]() {
84 std::vector<vertex_t> sorted(trackLabels);
86 sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());
90 if (trackLabels.size() == 0) {
94 std::vector<std::vector<int>> trackCandidates(numberLabels);
96 for (
const auto [
label,
id] :
Acts::zip(trackLabels, spacepointIDs)) {
97 trackCandidates[
label].push_back(
id);
100 return trackCandidates;