Faster etype calculation
Kacper-Pietkun committed Nov 24, 2023
1 parent 89c3f49 commit 80d28fc
Showing 3 changed files with 12 additions and 20 deletions.
6 changes: 2 additions & 4 deletions sar/construct_shard_manager.py
@@ -58,7 +58,6 @@ def make_induced_graph_shard_manager(full_graph_shards: List[ShardEdgesAndFeatures],
                                      edge_type_names: List[str],
                                      partition_book : dgl.distributed.GraphPartitionBook,
                                      node_types: Tensor,
-                                     all_shard_edges: ShardEdgesAndFeatures,
                                      keep_seed_nodes: bool = True) -> GraphShardManager:
     '''
     Creates new graph shards that only contain edges to the seed nodes. Adjusts the target
@@ -108,7 +107,7 @@ def make_induced_graph_shard_manager(full_graph_shards: List[ShardEdgesAndFeatures],
                                            src_range, tgt_range, edge_type_names))
 
     return GraphShardManager(graph_shard_list, src_compact_data['local_src_seed_nodes'],
-                             seed_nodes, partition_book, node_types, all_shard_edges)
+                             seed_nodes, partition_book, node_types)
 
 
 def compact_src_ranges(active_edges_src, seed_nodes, node_ranges, keep_seed_nodes):
@@ -184,7 +183,6 @@ def construct_mfgs(partition_data: PartitionData,
                                               partition_data.edge_type_names,
                                               partition_data.partition_book,
                                               partition_data.node_features[dgl.NTYPE],
-                                              partition_data.all_shard_edges,
                                               keep_seed_nodes)
         graph_shard_manager_list.append(gsm)
         seed_nodes = gsm.input_nodes + partition_data.node_ranges[rank()][0]
@@ -211,7 +209,7 @@ def construct_full_graph(partition_data: PartitionData) -> GraphShardManager:
     seed_nodes = torch.arange(partition_data.node_ranges[rank()][1] -
                               partition_data.node_ranges[rank()][0])
     return GraphShardManager(graph_shard_list, seed_nodes, seed_nodes, partition_data.partition_book,
-                             partition_data.node_features[dgl.NTYPE], partition_data.all_shard_edges)
+                             partition_data.node_features[dgl.NTYPE])
 
 def convert_dist_graph(dist_graph: dgl.distributed.DistGraph) -> GraphShardManager:
     partition_data = load_dgl_partition_data_from_graph(dist_graph, dist_graph.device)
22 changes: 8 additions & 14 deletions sar/core/graphshard.py
@@ -186,22 +186,18 @@ class GraphShardManager:
     :type partition_book: dgl.distributed.GraphPartitionBook
     :param node_types: tensor with node types in local partition
     :type node_types: torch.Tensor
-    :param all_shard_edges: A list of ShardEdgesAndFeatures objects. One for edges incoming from each partition
-    :type all_shard_edges: ShardEdgesAndFeatures
     """
 
     def __init__(self,
                  graph_shards: List[GraphShard],
                  local_src_seeds: Tensor,
                  local_tgt_seeds: Tensor,
                  partition_book: dgl.distributed.GraphPartitionBook,
-                 node_types: Tensor,
-                 all_shard_edges: ShardEdgesAndFeatures) -> None:
+                 node_types: Tensor) -> None:
         super().__init__()
         self.graph_shards = graph_shards
         self.partition_book = partition_book
         self.node_types = node_types
-        self.all_shard_edges = all_shard_edges
 
         # source nodes and target nodes are all the same
         # srcdata, dstdata and ndata should be also the same
@@ -651,16 +647,14 @@ def _calculate_ntypes(self):
             self._dst_node_types_count_dict[ntype] = dst_node_types_count[(dst_node_types_unique == type_index).nonzero().item()] if type_index in dst_node_types_unique else 0
 
     def _calculate_etypes(self):
-        for type_index, canonical_etype in enumerate(self.partition_book.canonical_etypes):
-            if type_index not in self._edge_types_count_dict:
+        for canonical_etype in self.partition_book.canonical_etypes:
             self._edge_types_count_dict[canonical_etype] = 0
-            for shard_edge in self.all_shard_edges:
-                if self.src_is_tgt:
-                    # For not MFGs we need to convert seeds from local to global numbering
-                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds + self.tgt_node_range[0])
-                else:
-                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds)
-                self._edge_types_count_dict[canonical_etype] += (shard_edge.edge_features[dgl.ETYPE][edge_mask] == type_index).sum().item()
+        for shard in self.graph_shards:
+            unique_type_ids, counts = torch.unique(shard.graph.edata[dgl.ETYPE], return_counts=True)
+
+            for idx, unique_type_id in enumerate(unique_type_ids):
+                canonical_etype = self.canonical_etypes_global[unique_type_id]
+                self._edge_types_count_dict[canonical_etype] += counts[idx].item()
 
 
     def _get_active_tensors(self, message_func):
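The rewritten method replaces a per-edge-type scan over all shard edges (one boolean mask and sum per canonical edge type) with a single torch.unique pass over each shard's local edge-type tensor. A minimal standalone sketch of the same counting pattern; the tensors and type names below are invented for illustration, not taken from the repository:

import torch

# Hypothetical edge-type IDs for one shard, as stored in graph.edata[dgl.ETYPE]
edge_type_ids = torch.tensor([0, 2, 2, 1, 0, 2], dtype=torch.int32)
canonical_etypes = [('user', 'follows', 'user'),
                    ('user', 'likes', 'item'),
                    ('item', 'rev-likes', 'user')]

# Initialize every type to zero so types absent from this shard still get an entry
counts_by_etype = {etype: 0 for etype in canonical_etypes}

# One vectorized pass yields each distinct type ID and its multiplicity,
# replacing O(num_types * num_edges) mask-and-sum work with a single sort-and-count
unique_type_ids, counts = torch.unique(edge_type_ids, return_counts=True)
for type_id, count in zip(unique_type_ids.tolist(), counts.tolist()):
    counts_by_etype[canonical_etypes[type_id]] += count

print(counts_by_etype)
# {('user', 'follows', 'user'): 2, ('user', 'likes', 'item'): 1, ('item', 'rev-likes', 'user'): 3}

The shard-local counting also drops the dependency on all_shard_edges entirely, which is why the parameter disappears from GraphShardManager and its callers above.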
4 changes: 2 additions & 2 deletions sar/data_loading.py
@@ -137,14 +137,14 @@ def create_partition_data(graph: dgl.DGLGraph,
     if dgl.NTYPE in graph.ndata:
         node_features[dgl.NTYPE] = graph.ndata[dgl.NTYPE][graph.ndata['inner_node'].bool()]
     else:
-        node_features[dgl.NTYPE] = torch.zeros(graph.num_nodes())[graph.ndata['inner_node'].bool()]
+        node_features[dgl.NTYPE] = torch.zeros(graph.num_nodes(), dtype=torch.int32)[graph.ndata['inner_node'].bool()]
 
     # Include the edge types in the edge feature dictionary
     inner_edge_mask = graph.edata['inner_edge'].bool()
     if dgl.ETYPE in graph.edata:
         edge_features[dgl.ETYPE] = graph.edata[dgl.ETYPE][inner_edge_mask]
     else:
-        edge_features[dgl.ETYPE] = torch.zeros(graph.num_edges())[inner_edge_mask]
+        edge_features[dgl.ETYPE] = torch.zeros(graph.num_edges(), dtype=torch.int32)[inner_edge_mask]
 
     # Obtain the inner edges. These are the partition edges
     local_partition_edges = torch.stack(graph.all_edges())[:, inner_edge_mask]
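torch.zeros defaults to torch.float32, so the placeholder type IDs created here used to be float tensors; the explicit dtype=torch.int32 matters because type IDs now serve as lookup indices (e.g., canonical_etypes_global[unique_type_id] in the hunk above). A quick illustration of the failure mode, with an invented lookup list:

import torch

etypes = [('user', 'follows', 'user'), ('user', 'likes', 'item')]
float_ids = torch.zeros(3)                   # dtype=torch.float32 by default
int_ids = torch.zeros(3, dtype=torch.int32)  # what the fixed code produces

print(etypes[int_ids[0]])  # ('user', 'follows', 'user') -- 0-dim integer tensors support __index__
# etypes[float_ids[0]] raises TypeError (a float tensor cannot be converted to an index)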
