9 #include <boost/test/unit_test.hpp>
19 #include <torch/torch.h>
// Euclidean (L2) distance between two tensors `a` and `b`.
// Preconditions (checked with assert, so only in builds without NDEBUG):
// both tensors have identical sizes and are rank-1.
// Returns sqrt(sum((a - b)^2)) as a float scalar extracted via item().
25 float distance(
const at::Tensor &
a,
const at::Tensor &
b) {
26 assert(a.sizes() == b.sizes());
// Only 1-D vectors are accepted.
27 assert(a.sizes().size() == 1);
// Sum of squared element-wise differences, reduced to a host-side float.
29 return std::sqrt(((a - b) * (a - b)).
sum().item().to<float>());
35 os <<
"(" <<
a <<
"," <<
b <<
")";
40 template <
typename edge_builder_t>
42 const edge_builder_t &edgeBuilder) {
44 auto random_features = at::randn({
n_nodes, emb_dim});
47 Eigen::MatrixXf distance_matrix(n_nodes, n_nodes);
49 std::vector<CantorPair> edges_ref_cantor;
50 std::vector<int> edge_counts(n_nodes, 0);
54 const auto d =
distance(random_features[
i], random_features[
j]);
55 distance_matrix(i, j) = d;
56 distance_matrix(j, i) = d;
58 if (d < r && i != j) {
59 edges_ref_cantor.emplace_back(i, j);
65 const auto max_edges =
66 *std::max_element(edge_counts.begin(), edge_counts.end());
72 BOOST_REQUIRE(max_edges <= knn);
75 auto edges_test = edgeBuilder(random_features, r, knn);
78 std::vector<CantorPair> edges_test_cantor;
80 for (
int i = 0;
i < edges_test.size(1); ++
i) {
81 const auto a = edges_test[0][
i].template item<int>();
82 const auto b = edges_test[1][
i].template item<int>();
86 std::sort(edges_ref_cantor.begin(), edges_ref_cantor.end());
87 std::sort(edges_test_cantor.begin(), edges_test_cantor.end());
90 std::cout <<
"test size " << edges_test_cantor.size() << std::endl;
91 std::cout <<
"ref size " << edges_ref_cantor.size() << std::endl;
92 std::cout <<
"test: ";
94 edges_test_cantor.begin(),
95 edges_test_cantor.begin() +
std::min(edges_test_cantor.size(), 10ul),
96 std::ostream_iterator<CantorPair>(std::cout,
" "));
97 std::cout << std::endl;
99 std::copy(edges_ref_cantor.begin(),
100 edges_ref_cantor.begin() +
std::min(edges_ref_cantor.size(), 10ul),
101 std::ostream_iterator<CantorPair>(std::cout,
" "));
102 std::cout << std::endl;
106 BOOST_CHECK(edges_ref_cantor.size() == edges_test_cantor.size());
107 BOOST_CHECK(std::equal(edges_test_cantor.begin(), edges_test_cantor.end(),
108 edges_ref_cantor.begin()));
115 const auto [aa, bb] =
CantorPair(a, b,
false).inverse();
116 BOOST_CHECK(a == aa);
117 BOOST_CHECK(b == bb);
135 *boost::unit_test::precondition([](
auto) {
136 return torch::cuda::is_available();
138 torch::manual_seed(seed);
140 auto cudaEdgeBuilder = [](
auto &features,
auto radius,
auto k) {
141 auto features_cuda = features.to(torch::kCUDA);
149 torch::manual_seed(seed);
151 auto cpuEdgeBuilder = [](
auto &features,
auto radius,
auto k) {
152 auto features_cpu = features.to(torch::kCPU);
161 std::vector<int64_t>
edges = {
169 auto opts = torch::TensorOptions().dtype(torch::kInt64);
170 const auto edgeTensor =
171 torch::from_blob(edges.data(), {
static_cast<long>(edges.size() / 2), 2},
175 const auto withoutSelfLoops =
180 const std::vector<int64_t> postEdges(
181 withoutSelfLoops.data_ptr<int64_t>(),
182 withoutSelfLoops.data_ptr<int64_t>() + withoutSelfLoops.numel());
185 const std::vector<int64_t>
ref = {
191 BOOST_CHECK(ref == postEdges);
196 std::vector<int64_t> edges = {
205 auto opts = torch::TensorOptions().dtype(torch::kInt64);
206 const auto edgeTensor =
207 torch::from_blob(edges.data(), {
static_cast<long>(edges.size() / 2), 2},
211 const auto withoutDups =
216 const std::vector<int64_t> postEdges(
217 withoutDups.data_ptr<int64_t>(),
218 withoutDups.data_ptr<int64_t>() + withoutDups.numel());
221 const std::vector<int64_t>
ref = {
228 BOOST_CHECK(ref == postEdges);
232 torch::manual_seed(seed);
235 std::vector<int64_t> edges = {
243 auto opts = torch::TensorOptions().dtype(torch::kInt64);
244 const auto edgeTensor =
245 torch::from_blob(edges.data(), {
static_cast<long>(edges.size() / 2), 2},
254 const std::vector<int64_t> postEdges(
255 flipped.data_ptr<int64_t>(),
256 flipped.data_ptr<int64_t>() + flipped.numel());
258 BOOST_CHECK(postEdges.size() == edges.size());
259 for (
auto preIt = edges.begin(); preIt != edges.end(); preIt += 2) {
262 for (
auto postIt = postEdges.begin(); postIt != postEdges.end();
264 bool noflp = (*preIt == *postIt) and *(preIt + 1) == *(postIt + 1);
265 bool flp = *preIt == *(postIt + 1) and *(preIt + 1) == *(postIt);
267 found += (flp or noflp);
270 BOOST_CHECK(found == 1);