Adding option to return info about sampled graph (#197)
This enables hierarchical tensor usage and a significant performance improvement.

PyG part: pyg-team/pytorch_geometric#6661

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Feb 28, 2023
1 parent 1760817 commit c04fb60
Showing 7 changed files with 112 additions and 28 deletions.
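For context, a minimal usage sketch of the extended homogeneous API, assuming a pyg-lib build that includes this commit; the toy CSR graph and the fan-outs below are illustrative and not taken from the commit itself:

```python
import torch

# Assumes pyg-lib is built from a revision that includes this commit.
from pyg_lib.sampler import neighbor_sample

# Toy directed graph in CSR form: 3 nodes with edges 0->1, 0->2 and 1->2.
rowptr = torch.tensor([0, 2, 3, 3])
col = torch.tensor([1, 2, 2])
seed = torch.tensor([0])

# Sample two hops with a fan-out of 2 per hop.
row, col_out, node_id, edge_id, num_nodes_per_hop, num_edges_per_hop = \
    neighbor_sample(rowptr, col, seed, [2, 2])

# `num_nodes_per_hop` holds one entry for the seed nodes followed by one entry
# per hop; `num_edges_per_hop` holds one entry per hop. Both are plain Python
# lists of ints describing how much was sampled in each layer.
print(num_nodes_per_hop, num_edges_per_hop)
```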
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.2.0] - 2023-MM-DD
### Added
- `neighbor_sample` routines now also return information about the number of sampled nodes/edges per layer ([#197](https://github.com/pyg-team/pyg-lib/pull/197))
- Added `index_sort` implementation ([#181](https://github.com/pyg-team/pyg-lib/pull/181), [#192](https://github.com/pyg-team/pyg-lib/pull/192))
- Added `triton>=2.0` support ([#171](https://github.com/pyg-team/pyg-lib/pull/171))
- Added `bias` term to `grouped_matmul` and `segment_matmul` ([#161](https://github.com/pyg-team/pyg-lib/pull/161))
59 changes: 51 additions & 8 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
@@ -170,6 +170,7 @@ class NeighborSampler {
out_global_dst_nodes.push_back(global_dst_node);
}
if (save_edges) {
num_sampled_edges_per_hop[num_sampled_edges_per_hop.size() - 1]++;
sampled_rows_.push_back(local_src_node);
sampled_cols_.push_back(res.first);
if (save_edge_ids) {
@@ -184,12 +185,20 @@ class NeighborSampler {
std::vector<scalar_t> sampled_rows_;
std::vector<scalar_t> sampled_cols_;
std::vector<scalar_t> sampled_edge_ids_;

public:
std::vector<int64_t> num_sampled_edges_per_hop;
};

// Homogeneous neighbor sampling ///////////////////////////////////////////////

template <bool replace, bool directed, bool disjoint, bool return_edge_id>
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
@@ -214,6 +223,8 @@ sample(const at::Tensor& rowptr,

at::Tensor out_row, out_col, out_node_id;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
std::vector<int64_t> num_sampled_nodes_per_hop;
std::vector<int64_t> num_sampled_edges_per_hop;

AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] {
typedef std::pair<scalar_t, scalar_t> pair_scalar_t;
@@ -255,10 +266,12 @@ sample(const at::Tensor& rowptr,
}
}

num_sampled_nodes_per_hop.push_back(seed.numel());

size_t begin = 0, end = seed.size(0);
for (size_t ell = 0; ell < num_neighbors.size(); ++ell) {
const auto count = num_neighbors[ell];

sampler.num_sampled_edges_per_hop.push_back(0);
if (!time.has_value()) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i],
@@ -277,18 +290,22 @@ sample(const at::Tensor& rowptr,
}
}
begin = end, end = sampled_nodes.size();
num_sampled_nodes_per_hop.push_back(end - begin);
}

out_node_id = pyg::utils::from_vector(sampled_nodes);

TORCH_CHECK(directed, "Undirected subgraphs not yet supported");
if (directed) {
std::tie(out_row, out_col, out_edge_id) = sampler.get_sampled_edges(csc);
} else {
TORCH_CHECK(!disjoint, "Disjoint subgraphs not yet supported");
}

num_sampled_edges_per_hop = sampler.num_sampled_edges_per_hop;
});
return std::make_tuple(out_row, out_col, out_node_id, out_edge_id);

return std::make_tuple(out_row, out_col, out_node_id, out_edge_id,
num_sampled_nodes_per_hop, num_sampled_edges_per_hop);
}

// Heterogeneous neighbor sampling /////////////////////////////////////////////
@@ -297,7 +314,9 @@ template <bool replace, bool directed, bool disjoint, bool return_edge_id>
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
sample(const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
@@ -344,6 +363,10 @@ sample(const std::vector<node_type>& node_types,
} else {
out_edge_id_dict = c10::nullopt;
}
std::unordered_map<node_type, std::vector<int64_t>>
num_sampled_nodes_per_hop_map;
c10::Dict<node_type, std::vector<int64_t>> num_sampled_nodes_per_hop_dict;
c10::Dict<rel_type, std::vector<int64_t>> num_sampled_edges_per_hop_dict;

const auto scalar_type = seed_dict.begin()->value().scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "hetero_sample_kernel", [&] {
@@ -378,6 +401,7 @@ sample(const std::vector<node_type>& node_types,
for (const auto& k : node_types) {
const auto N = num_nodes_dict.count(k) > 0 ? num_nodes_dict.at(k) : 0;
sampled_nodes_dict[k]; // Initialize empty vector.
num_sampled_nodes_per_hop_map.insert({k, std::vector<int64_t>(1, 0)});
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
}
@@ -422,6 +446,9 @@ sample(const std::vector<node_type>& node_types,
}
}
}

num_sampled_nodes_per_hop_map.at(kv.key())[0] =
sampled_nodes_dict.at(kv.key()).size();
}

size_t begin, end;
@@ -436,6 +463,8 @@ sample(const std::vector<node_type>& node_types,
auto& sampler = sampler_dict.at(k);
std::tie(begin, end) = slice_dict.at(src);

sampler.num_sampled_edges_per_hop.push_back(0);

if (!time_dict.has_value() || !time_dict.value().contains(dst)) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/src_sampled_nodes[i],
@@ -457,12 +486,16 @@ sample(const std::vector<node_type>& node_types,
for (const auto& k : node_types) {
slice_dict[k] = {slice_dict.at(k).second,
sampled_nodes_dict.at(k).size()};
num_sampled_nodes_per_hop_map.at(k).push_back(slice_dict.at(k).second -
slice_dict.at(k).first);
}
}

for (const auto& k : node_types) {
out_node_id_dict.insert(
k, pyg::utils::from_vector(sampled_nodes_dict.at(k)));
num_sampled_nodes_per_hop_dict.insert(
k, num_sampled_nodes_per_hop_map.at(k));
}

TORCH_CHECK(directed, "Undirected heterogeneous graphs not yet supported");
@@ -471,6 +504,8 @@ sample(const std::vector<node_type>& node_types,
const auto edges = sampler_dict.at(k).get_sampled_edges(csc);
out_row_dict.insert(to_rel_type(k), std::get<0>(edges));
out_col_dict.insert(to_rel_type(k), std::get<1>(edges));
num_sampled_edges_per_hop_dict.insert(
to_rel_type(k), sampler_dict.at(k).num_sampled_edges_per_hop);
if (return_edge_id) {
out_edge_id_dict.value().insert(to_rel_type(k),
std::get<2>(edges).value());
@@ -480,7 +515,8 @@ sample(const std::vector<node_type>& node_types,
});

return std::make_tuple(out_row_dict, out_col_dict, out_node_id_dict,
out_edge_id_dict);
out_edge_id_dict, num_sampled_nodes_per_hop_dict,
num_sampled_edges_per_hop_dict);
}

// Dispatcher //////////////////////////////////////////////////////////////////
@@ -521,7 +557,12 @@ sample(const std::vector<node_type>& node_types,

} // namespace

std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
@@ -541,7 +582,9 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
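The per-hop counters added above are what makes the "hierarchical tensor usage" from the commit message possible: the flat outputs can be split back into per-hop slices on the consumer side. A short sketch of this assumed downstream usage (the concrete counts are hypothetical):

```python
from typing import List

import torch
from torch import Tensor


def split_per_hop(node_id: Tensor, num_nodes_per_hop: List[int]) -> List[Tensor]:
    # `node_id` is the flat tensor of sampled nodes returned by the kernel;
    # consecutive slices of it correspond to the seed nodes and then to each
    # sampling hop, in order.
    return list(node_id.split(num_nodes_per_hop))


# Hypothetical result of a 2-hop sample that started from 2 seed nodes:
node_id = torch.arange(6)
per_hop = split_per_hop(node_id, [2, 2, 2])
# per_hop[0] -> seed nodes, per_hop[1] -> hop-1 nodes, per_hop[2] -> hop-2 nodes
```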
11 changes: 9 additions & 2 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
@@ -5,7 +5,12 @@
namespace pyg {
namespace sampler {

std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
@@ -22,7 +27,9 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
16 changes: 12 additions & 4 deletions pyg_lib/csrc/sampler/neighbor.cpp
@@ -8,7 +8,12 @@
namespace pyg {
namespace sampler {

std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
@@ -40,7 +45,9 @@ neighbor_sample(const at::Tensor& rowptr,
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
@@ -90,7 +97,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc "
"= False, bool replace = False, bool directed = True, bool disjoint = "
"False, str temporal_strategy = 'uniform', bool return_edge_id = True) "
"-> (Tensor, Tensor, Tensor, Tensor?)"));
"-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
@@ -99,7 +106,8 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"= None, bool csc = False, bool replace = False, bool directed = True, "
"bool disjoint = False, str temporal_strategy = 'uniform', bool "
"return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor), Dict(str, Tensor)?)"));
"Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), "
"Dict(str, int[]))"));
}

} // namespace sampler
11 changes: 9 additions & 2 deletions pyg_lib/csrc/sampler/neighbor.h
@@ -11,7 +11,12 @@ namespace sampler {
// in the graph given by `(rowptr, col)`.
// Returns: (row, col, node_id, edge_id)
PYG_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
@@ -32,7 +37,9 @@ PYG_API
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
34 changes: 22 additions & 12 deletions pyg_lib/sampler/__init__.py
@@ -21,7 +21,7 @@ def neighbor_sample(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]:
r"""Recursively samples neighbors from all node indices in :obj:`seed`
in the graph given by :obj:`(rowptr, col)`.
@@ -64,10 +64,13 @@ def neighbor_sample(
(default: :obj:`True`)
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]):
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
List[int], List[int]):
Row indices, col indices of the returned subtree/subgraph, as well as
original node indices for all nodes sampled.
In addition, may return the indices of edges of the original graph.
Lastly, returns information about the number of sampled nodes and edges
per hop.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, seed_time, csc, replace,
@@ -89,7 +92,8 @@ def hetero_neighbor_sample(
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[
NodeType, Tensor], Optional[Dict[EdgeType, Tensor]]]:
NodeType, Tensor], Optional[Dict[EdgeType, Tensor]], Dict[
NodeType, List[int]], Dict[NodeType, List[int]]]:
r"""Recursively samples neighbors from all node indices in :obj:`seed_dict`
in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`.
@@ -133,16 +137,22 @@ def hetero_neighbor_sample(
return_edge_id,
)

out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict = out
out_row_dict = {TO_EDGE_TYPE[k]: v for k, v in out_row_dict.items()}
out_col_dict = {TO_EDGE_TYPE[k]: v for k, v in out_col_dict.items()}
if out_edge_id_dict is not None:
out_edge_id_dict = {
TO_EDGE_TYPE[k]: v
for k, v in out_edge_id_dict.items()
}
(row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict,
num_edges_per_hop_dict) = out

return out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict
row_dict = {TO_EDGE_TYPE[k]: v for k, v in row_dict.items()}
col_dict = {TO_EDGE_TYPE[k]: v for k, v in col_dict.items()}

if edge_id_dict is not None:
edge_id_dict = {TO_EDGE_TYPE[k]: v for k, v in edge_id_dict.items()}

num_edges_per_hop_dict = {
TO_EDGE_TYPE[k]: v
for k, v in num_edges_per_hop_dict.items()
}

return (row_dict, col_dict, node_id_dict, edge_id_dict,
num_nodes_per_hop_dict, num_edges_per_hop_dict)


def subgraph(
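In the heterogeneous wrapper above, the two new outputs are dictionaries of per-hop counts: node counts keyed by node type and edge counts keyed by edge type (after the `TO_EDGE_TYPE` mapping). A self-contained sketch of reducing them to per-hop totals; all keys and numbers below are hypothetical:

```python
from typing import Dict, List, Tuple

# Shapes mirror the wrapper's return values: node-type strings map to per-hop
# node counts (leading entry counts the seeds), and edge-type tuples map to
# per-hop edge counts.
num_nodes_per_hop_dict: Dict[str, List[int]] = {
    'paper': [2, 2, 2],
    'author': [0, 3, 1],
}
num_edges_per_hop_dict: Dict[Tuple[str, str, str], List[int]] = {
    ('paper', 'to', 'paper'): [4, 4],
    ('author', 'writes', 'paper'): [3, 2],
}

# Total number of nodes/edges sampled in each hop, across all types:
total_nodes_per_hop = [sum(h) for h in zip(*num_nodes_per_hop_dict.values())]
total_edges_per_hop = [sum(h) for h in zip(*num_edges_per_hop_dict.values())]
print(total_nodes_per_hop)  # [2, 5, 3]
print(total_edges_per_hop)  # [7, 6]
```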
8 changes: 8 additions & 0 deletions test/csrc/sampler/test_neighbor.cpp
@@ -24,6 +24,10 @@ TEST(FullNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes));
auto expected_edges = at::tensor({4, 5, 6, 7, 2, 3, 8, 9}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
std::vector<int64_t> expected_num_nodes = {2, 2, 2};
EXPECT_TRUE(std::get<4>(out) == expected_num_nodes);
std::vector<int64_t> expected_num_edges = {4, 4};
EXPECT_TRUE(std::get<5>(out) == expected_num_edges);
}

TEST(WithoutReplacementNeighborTest, BasicAssertions) {
@@ -168,4 +172,8 @@ TEST(HeteroNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<2>(out).at(node_key), expected_nodes));
auto expected_edges = at::tensor({4, 5, 6, 7, 2, 3, 8, 9}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value().at(rel_key), expected_edges));
std::vector<int64_t> expected_num_nodes = {2, 2, 2};
EXPECT_TRUE(std::get<4>(out).at("paper") == expected_num_nodes);
std::vector<int64_t> expected_num_edges = {4, 4};
EXPECT_TRUE(std::get<5>(out).at("paper__to__paper") == expected_num_edges);
}
