diff --git a/CHANGELOG.md b/CHANGELOG.md index c2a1b04b0..6c7adf614 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.2.0] - 2023-MM-DD ### Added +- `neighbor_sample` routines now also return information about the number of sampled nodes/edges per layer ([#197](https://github.com/pyg-team/pyg-lib/pull/197)) - Added `index_sort` implementation ([#181](https://github.com/pyg-team/pyg-lib/pull/181), [#192](https://github.com/pyg-team/pyg-lib/pull/192)) - Added `triton>=2.0` support ([#171](https://github.com/pyg-team/pyg-lib/pull/171)) - Added `bias` term to `grouped_matmul` and `segment_matmul` ([#161](https://github.com/pyg-team/pyg-lib/pull/161)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 27d370be3..4f91974e4 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -170,6 +170,7 @@ class NeighborSampler { out_global_dst_nodes.push_back(global_dst_node); } if (save_edges) { + num_sampled_edges_per_hop[num_sampled_edges_per_hop.size() - 1]++; sampled_rows_.push_back(local_src_node); sampled_cols_.push_back(res.first); if (save_edge_ids) { @@ -184,12 +185,20 @@ class NeighborSampler { std::vector sampled_rows_; std::vector sampled_cols_; std::vector sampled_edge_ids_; + + public: + std::vector num_sampled_edges_per_hop; }; // Homogeneous neighbor sampling /////////////////////////////////////////////// template -std::tuple> +std::tuple, + std::vector, + std::vector> sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, @@ -214,6 +223,8 @@ sample(const at::Tensor& rowptr, at::Tensor out_row, out_col, out_node_id; c10::optional out_edge_id = c10::nullopt; + std::vector num_sampled_nodes_per_hop; + std::vector num_sampled_edges_per_hop; AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] { typedef std::pair pair_scalar_t; @@ -255,10 +266,12 @@ sample(const at::Tensor& rowptr, } } + num_sampled_nodes_per_hop.push_back(seed.numel()); + size_t begin = 0, end = seed.size(0); for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { const auto count = num_neighbors[ell]; - + sampler.num_sampled_edges_per_hop.push_back(0); if (!time.has_value()) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], @@ -277,18 +290,22 @@ sample(const at::Tensor& rowptr, } } begin = end, end = sampled_nodes.size(); + num_sampled_nodes_per_hop.push_back(end - begin); } out_node_id = pyg::utils::from_vector(sampled_nodes); - TORCH_CHECK(directed, "Undirected subgraphs not yet supported"); if (directed) { std::tie(out_row, out_col, out_edge_id) = sampler.get_sampled_edges(csc); } else { TORCH_CHECK(!disjoint, "Disjoint subgraphs not yet supported"); } + + num_sampled_edges_per_hop = sampler.num_sampled_edges_per_hop; }); - return std::make_tuple(out_row, out_col, out_node_id, out_edge_id); + + return std::make_tuple(out_row, out_col, out_node_id, out_edge_id, + num_sampled_nodes_per_hop, num_sampled_edges_per_hop); } // Heterogeneous neighbor sampling ///////////////////////////////////////////// @@ -297,7 +314,9 @@ template std::tuple, c10::Dict, c10::Dict, - c10::optional>> + c10::optional>, + c10::Dict>, + c10::Dict>> sample(const std::vector& node_types, const std::vector& edge_types, const c10::Dict& rowptr_dict, @@ -344,6 +363,10 @@ sample(const std::vector& node_types, } else { out_edge_id_dict = c10::nullopt; } + std::unordered_map> + num_sampled_nodes_per_hop_map; + c10::Dict> num_sampled_nodes_per_hop_dict; + c10::Dict> num_sampled_edges_per_hop_dict; const auto scalar_type = seed_dict.begin()->value().scalar_type(); AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "hetero_sample_kernel", [&] { @@ -378,6 +401,7 @@ sample(const std::vector& node_types, for (const auto& k : node_types) { const auto N = num_nodes_dict.count(k) > 0 ? num_nodes_dict.at(k) : 0; sampled_nodes_dict[k]; // Initialize empty vector. + num_sampled_nodes_per_hop_map.insert({k, std::vector(1, 0)}); mapper_dict.insert({k, Mapper(N)}); slice_dict[k] = {0, 0}; } @@ -422,6 +446,9 @@ sample(const std::vector& node_types, } } } + + num_sampled_nodes_per_hop_map.at(kv.key())[0] = + sampled_nodes_dict.at(kv.key()).size(); } size_t begin, end; @@ -436,6 +463,8 @@ sample(const std::vector& node_types, auto& sampler = sampler_dict.at(k); std::tie(begin, end) = slice_dict.at(src); + sampler.num_sampled_edges_per_hop.push_back(0); + if (!time_dict.has_value() || !time_dict.value().contains(dst)) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample(/*global_src_node=*/src_sampled_nodes[i], @@ -457,12 +486,16 @@ sample(const std::vector& node_types, for (const auto& k : node_types) { slice_dict[k] = {slice_dict.at(k).second, sampled_nodes_dict.at(k).size()}; + num_sampled_nodes_per_hop_map.at(k).push_back(slice_dict.at(k).second - + slice_dict.at(k).first); } } for (const auto& k : node_types) { out_node_id_dict.insert( k, pyg::utils::from_vector(sampled_nodes_dict.at(k))); + num_sampled_nodes_per_hop_dict.insert( + k, num_sampled_nodes_per_hop_map.at(k)); } TORCH_CHECK(directed, "Undirected heterogeneous graphs not yet supported"); @@ -471,6 +504,8 @@ sample(const std::vector& node_types, const auto edges = sampler_dict.at(k).get_sampled_edges(csc); out_row_dict.insert(to_rel_type(k), std::get<0>(edges)); out_col_dict.insert(to_rel_type(k), std::get<1>(edges)); + num_sampled_edges_per_hop_dict.insert( + to_rel_type(k), sampler_dict.at(k).num_sampled_edges_per_hop); if (return_edge_id) { out_edge_id_dict.value().insert(to_rel_type(k), std::get<2>(edges).value()); @@ -480,7 +515,8 @@ sample(const std::vector& node_types, }); return std::make_tuple(out_row_dict, out_col_dict, out_node_id_dict, - out_edge_id_dict); + out_edge_id_dict, num_sampled_nodes_per_hop_dict, + num_sampled_edges_per_hop_dict); } // Dispatcher ////////////////////////////////////////////////////////////////// @@ -521,7 +557,12 @@ sample(const std::vector& node_types, } // namespace -std::tuple> +std::tuple, + std::vector, + std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, @@ -541,7 +582,9 @@ neighbor_sample_kernel(const at::Tensor& rowptr, std::tuple, c10::Dict, c10::Dict, - c10::optional>> + c10::optional>, + c10::Dict>, + c10::Dict>> hetero_neighbor_sample_kernel( const std::vector& node_types, const std::vector& edge_types, diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 64b5cc44f..0e0a532f2 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -5,7 +5,12 @@ namespace pyg { namespace sampler { -std::tuple> +std::tuple, + std::vector, + std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, @@ -22,7 +27,9 @@ neighbor_sample_kernel(const at::Tensor& rowptr, std::tuple, c10::Dict, c10::Dict, - c10::optional>> + c10::optional>, + c10::Dict>, + c10::Dict>> hetero_neighbor_sample_kernel( const std::vector& node_types, const std::vector& edge_types, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 9d4b3c9b2..f0550b78a 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -8,7 +8,12 @@ namespace pyg { namespace sampler { -std::tuple> +std::tuple, + std::vector, + std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, @@ -40,7 +45,9 @@ neighbor_sample(const at::Tensor& rowptr, std::tuple, c10::Dict, c10::Dict, - c10::optional>> + c10::optional>, + c10::Dict>, + c10::Dict>> hetero_neighbor_sample( const std::vector& node_types, const std::vector& edge_types, @@ -90,7 +97,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " "= False, bool replace = False, bool directed = True, bool disjoint = " "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " - "-> (Tensor, Tensor, Tensor, Tensor?)")); + "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " @@ -99,7 +106,8 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "= None, bool csc = False, bool replace = False, bool directed = True, " "bool disjoint = False, str temporal_strategy = 'uniform', bool " "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " - "Dict(str, Tensor), Dict(str, Tensor)?)")); + "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " + "Dict(str, int[]))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 881f1d8cc..55114450a 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -11,7 +11,12 @@ namespace sampler { // in the graph given by `(rowptr, col)`. // Returns: (row, col, node_id, edge_id) PYG_API -std::tuple> +std::tuple, + std::vector, + std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, @@ -32,7 +37,9 @@ PYG_API std::tuple, c10::Dict, c10::Dict, - c10::optional>> + c10::optional>, + c10::Dict>, + c10::Dict>> hetero_neighbor_sample( const std::vector& node_types, const std::vector& edge_types, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 1e1628cb5..2cfea286c 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -21,7 +21,7 @@ def neighbor_sample( disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]: r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. @@ -64,10 +64,13 @@ def neighbor_sample( (default: :obj: `True`) Returns: - (torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]): + (torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], + List[int], List[int]): Row indices, col indices of the returned subtree/subgraph, as well as original node indices for all nodes sampled. In addition, may return the indices of edges of the original graph. + Lastly, returns information about the sampled amount of nodes and edges + per hop. """ return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, time, seed_time, csc, replace, @@ -89,7 +92,8 @@ def hetero_neighbor_sample( temporal_strategy: str = 'uniform', return_edge_id: bool = True, ) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[ - NodeType, Tensor], Optional[Dict[EdgeType, Tensor]]]: + NodeType, Tensor], Optional[Dict[EdgeType, Tensor]], Dict[ + NodeType, List[int]], Dict[NodeType, List[int]]]: r"""Recursively samples neighbors from all node indices in :obj:`seed_dict` in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`. @@ -133,16 +137,22 @@ def hetero_neighbor_sample( return_edge_id, ) - out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict = out - out_row_dict = {TO_EDGE_TYPE[k]: v for k, v in out_row_dict.items()} - out_col_dict = {TO_EDGE_TYPE[k]: v for k, v in out_col_dict.items()} - if out_edge_id_dict is not None: - out_edge_id_dict = { - TO_EDGE_TYPE[k]: v - for k, v in out_edge_id_dict.items() - } + (row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict, + num_edges_per_hop_dict) = out - return out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict + row_dict = {TO_EDGE_TYPE[k]: v for k, v in row_dict.items()} + col_dict = {TO_EDGE_TYPE[k]: v for k, v in col_dict.items()} + + if edge_id_dict is not None: + edge_id_dict = {TO_EDGE_TYPE[k]: v for k, v in edge_id_dict.items()} + + num_edges_per_hop_dict = { + TO_EDGE_TYPE[k]: v + for k, v in num_edges_per_hop_dict.items() + } + + return (row_dict, col_dict, node_id_dict, edge_id_dict, + num_nodes_per_hop_dict, num_edges_per_hop_dict) def subgraph( diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index ded75e488..2c3c75709 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -24,6 +24,10 @@ TEST(FullNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); auto expected_edges = at::tensor({4, 5, 6, 7, 2, 3, 8, 9}, options); EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + std::vector expected_num_nodes = {2, 2, 2}; + EXPECT_TRUE(std::get<4>(out) == expected_num_nodes); + std::vector expected_num_edges = {4, 4}; + EXPECT_TRUE(std::get<5>(out) == expected_num_edges); } TEST(WithoutReplacementNeighborTest, BasicAssertions) { @@ -168,4 +172,8 @@ TEST(HeteroNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<2>(out).at(node_key), expected_nodes)); auto expected_edges = at::tensor({4, 5, 6, 7, 2, 3, 8, 9}, options); EXPECT_TRUE(at::equal(std::get<3>(out).value().at(rel_key), expected_edges)); + std::vector expected_num_nodes = {2, 2, 2}; + EXPECT_TRUE(std::get<4>(out).at("paper") == expected_num_nodes); + std::vector expected_num_edges = {4, 4}; + EXPECT_TRUE(std::get<5>(out).at("paper__to__paper") == expected_num_edges); }