13 #include <torch/script.h>
20 std::any, std::any
edges, std::any edge_weights,
21 std::vector<int> &spacepointIDs,
int) {
22 auto numSpacepoints = spacepointIDs.size();
23 auto edgesAfterFiltering = std::any_cast<std::vector<int64_t>>(
edges);
24 auto numEdgesAfterF = edgesAfterFiltering.size() / 2;
25 auto gOutputCTen = std::any_cast<at::Tensor>(edge_weights);
27 if (numEdgesAfterF == 0) {
34 std::vector<int32_t> rowIndices;
35 std::vector<int32_t> colIndices;
36 std::vector<float> edgeWeights;
37 std::vector<int32_t> trackLabels(numSpacepoints);
38 std::copy(edgesAfterFiltering.begin(),
39 edgesAfterFiltering.begin() + numEdgesAfterF,
40 std::back_insert_iterator(rowIndices));
41 std::copy(edgesAfterFiltering.begin() + numEdgesAfterF,
42 edgesAfterFiltering.end(), std::back_insert_iterator(colIndices));
43 std::copy(gOutputCTen.data_ptr<
float>(),
44 gOutputCTen.data_ptr<
float>() + numEdgesAfterF,
45 std::back_insert_iterator(edgeWeights));
48 weaklyConnectedComponents<int32_t, int32_t, float>(
49 rowIndices, colIndices, edgeWeights, trackLabels,
logger());
51 ACTS_DEBUG(
"size of components: " << trackLabels.size());
52 if (trackLabels.size() == 0) {
56 std::vector<std::vector<int>> trackCandidates;
57 trackCandidates.clear();
61 std::map<int, int> trackLableToIds;
63 for (
auto idx = 0ul;
idx < numSpacepoints; ++
idx) {
64 int trackLabel = trackLabels[
idx];
65 int spacepointID = spacepointIDs[
idx];
68 if (trackLableToIds.find(trackLabel) != trackLableToIds.end()) {
69 trkId = trackLableToIds[trackLabel];
70 trackCandidates[trkId].push_back(spacepointID);
75 trackCandidates.push_back(std::vector<int>{trkId});
76 trackLableToIds[trackLabel] = trkId;
81 return trackCandidates;