Skip to content

Commit

Permalink
Apply changes from Matthias review
Browse files Browse the repository at this point in the history
  • Loading branch information
mszarma committed Feb 27, 2023
1 parent 5209199 commit 7f56bcc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
32 changes: 16 additions & 16 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class NeighborSampler {
NeighborSampler(const scalar_t* rowptr,
const scalar_t* col,
const std::string temporal_strategy)
: rowptr_(rowptr), col_(col), temporal_strategy_(temporal_strategy), num_neighbors_counter(0) {
: rowptr_(rowptr), col_(col), temporal_strategy_(temporal_strategy), curr_hop(0) {

TORCH_CHECK(temporal_strategy == "uniform" || temporal_strategy == "last",
"No valid temporal strategy found");
Expand Down Expand Up @@ -96,8 +96,8 @@ class NeighborSampler {
}

std::vector<scalar_t>
get_sampled_edges_num() {
return sampled_edges_per_num_neighbors;
get_sampled_info() {
return sampled_edges_per_hop;
}


Expand Down Expand Up @@ -180,7 +180,7 @@ class NeighborSampler {
}
if (save_edges) {
if (return_sampled_info) {
sampled_edges_per_num_neighbors[num_neighbors_counter]++;
sampled_edges_per_hop[curr_hop]++;
}
sampled_rows_.push_back(local_src_node);
sampled_cols_.push_back(res.first);
Expand All @@ -198,8 +198,8 @@ class NeighborSampler {
std::vector<scalar_t> sampled_edge_ids_;
;
public:
scalar_t num_neighbors_counter;
std::vector<scalar_t> sampled_edges_per_num_neighbors;
scalar_t curr_hop;
std::vector<scalar_t> sampled_edges_per_hop;
};

// Homogeneous neighbor sampling ///////////////////////////////////////////////
Expand Down Expand Up @@ -230,7 +230,7 @@ sample(const at::Tensor& rowptr,
at::Tensor out_row, out_col, out_node_id;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
at::Tensor edges_num;
std::vector<int64_t> nodes_num_per_neighbor;
std::vector<int64_t> num_nodes_per_hop;
at::Tensor nodes_num;
AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] {
typedef std::pair<scalar_t, scalar_t> pair_scalar_t;
Expand All @@ -256,7 +256,7 @@ sample(const at::Tensor& rowptr,
sampled_nodes = pyg::utils::to_vector<scalar_t>(seed);
mapper.fill(seed);
if constexpr (return_sampled_info)
nodes_num_per_neighbor.push_back(seed.numel());
num_nodes_per_hop.push_back(seed.numel());
} else {
for (size_t i = 0; i < seed.numel(); ++i) {
sampled_nodes.push_back({i, seed_data[i]});
Expand All @@ -279,15 +279,15 @@ sample(const at::Tensor& rowptr,
for (size_t ell = 0; ell < num_neighbors.size(); ++ell) {
const auto count = num_neighbors[ell];
if (return_sampled_info)
sampler.sampled_edges_per_num_neighbors.push_back(0);
sampler.sampled_edges_per_hop.push_back(0);
if (!time.has_value()) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, count, mapper, generator,
/*out_global_dst_nodes=*/sampled_nodes);
}
if (return_sampled_info)
sampler.num_neighbors_counter++;
sampler.curr_hop++;
} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const auto time_data = time.value().data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
Expand All @@ -301,14 +301,14 @@ sample(const at::Tensor& rowptr,
}
begin = end, end = sampled_nodes.size();
if (return_sampled_info)
nodes_num_per_neighbor.push_back(end - begin);
num_nodes_per_hop.push_back(end - begin);
}

out_node_id = pyg::utils::from_vector(sampled_nodes);
if (return_sampled_info) {
std::vector<scalar_t> edges_num_per_neighbors = sampler.get_sampled_edges_num();
std::vector<scalar_t> edges_num_per_neighbors = sampler.get_sampled_info();
edges_num = pyg::utils::from_vector(edges_num_per_neighbors);
nodes_num = pyg::utils::from_vector(nodes_num_per_neighbor);
nodes_num = pyg::utils::from_vector(num_nodes_per_hop);
}
TORCH_CHECK(directed, "Undirected subgraphs not yet supported");
if (directed) {
Expand Down Expand Up @@ -480,14 +480,14 @@ sample(const std::vector<node_type>& node_types,

if (!time_dict.has_value() || !time_dict.value().contains(dst)) {
if (!disjoint && return_sampled_info)
sampler.sampled_edges_per_num_neighbors.push_back(0);
sampler.sampled_edges_per_hop.push_back(0);
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i, count, dst_mapper,
generator, dst_sampled_nodes);
}
if constexpr(!disjoint && return_sampled_info)
sampler.num_neighbors_counter++;
sampler.curr_hop++;

} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const at::Tensor& dst_time = time_dict.value().at(dst);
Expand Down Expand Up @@ -534,7 +534,7 @@ sample(const std::vector<node_type>& node_types,
std::get<2>(edges).value());
}
if constexpr(!disjoint && return_sampled_info) {
std::vector<scalar_t> edges_num_per_neighbors = sampler_dict.at(k).get_sampled_edges_num();
std::vector<scalar_t> edges_num_per_neighbors = sampler_dict.at(k).get_sampled_info();
edges_num_dict.insert(to_rel_type(k), pyg::utils::from_vector(edges_num_per_neighbors));
}
}
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def neighbor_sample(
return the indices of edges of the original graph.
(default: :obj: `True`)
return_sampled_info (bool): If set to :obj:`True`, will return information about
amount of sampled nodes and edges per layer.
the amount of sampled nodes and edges per layer.
(default: :obj: `False`)
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]):
Row indices, col indices of the returned subtree/subgraph, as well as
original node indices for all nodes sampled.
In addition, may return the indices of edges of the original graph,
as well as information about sampled amount of nodes and edges per layer.
as well as information about the sampled amount of nodes and edges per layer.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, seed_time, csc, replace,
Expand Down

0 comments on commit 7f56bcc

Please sign in to comment.