lazy ntypes/etypes calculation
Kacper-Pietkun committed Nov 21, 2023
1 parent 68bbd5c commit 89c3f49
Showing 1 changed file with 55 additions and 30 deletions.
85 changes: 55 additions & 30 deletions sar/core/graphshard.py
@@ -185,9 +185,9 @@ class GraphShardManager:
     :param partition_book: The graph partition information
     :type partition_book: dgl.distributed.GraphPartitionBook
     :param node_types: tensor with node types in local partition
-    :param node_types: torch.Tensor
+    :type node_types: torch.Tensor
     :param all_shard_edges: A list of ShardEdgesAndFeatures objects. One for edges incoming from each partition
-    :param all_shard_edges: ShardEdgesAndFeatures
+    :type all_shard_edges: ShardEdgesAndFeatures
     """

     def __init__(self,
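
For context, the two replaced docstring lines switch a duplicated :param field to a :type field. In Sphinx-style (reST) docstrings each argument gets one ":param name:" line for its description and one ":type name:" line for its type; a purely illustrative example, not code from the repository:

    def scale(values, factor):
        """Multiply every element of a tensor by a scalar.

        :param values: Input tensor to be scaled.
        :type values: torch.Tensor
        :param factor: Scaling coefficient applied to each element.
        :type factor: float
        """
        return values * factor
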
@@ -200,6 +200,8 @@ def __init__(self,
         super().__init__()
         self.graph_shards = graph_shards
         self.partition_book = partition_book
+        self.node_types = node_types
+        self.all_shard_edges = all_shard_edges

         # source nodes and target nodes are all the same
         # srcdata, dstdata and ndata should be also the same
@@ -247,35 +249,34 @@ def __init__(self,
         else:
             self.dstdata = ChainedDataView(self.num_dst_nodes())

-        # Preparing dictionaries for storing number of nodes with given type in a graph
-        self.src_node_types_count_dict = {}
-        self.dst_node_types_count_dict = {}
-        src_node_types_unique, src_node_types_count = torch.unique(node_types[self.input_nodes], return_counts=True)
-        if self.src_is_tgt:
-            dst_node_types_unique, dst_node_types_count = torch.unique(node_types[self.seeds], return_counts=True)
-        else:
-            # For MFGs we need to convert seeds from global to the local numbering
-            dst_node_types_unique, dst_node_types_count = torch.unique(node_types[self.seeds - self.tgt_node_range[0]], return_counts=True)
-        for ntype in self.partition_book.ntypes:
-            type_index = partition_book.ntypes.index(ntype)
-            self.src_node_types_count_dict[ntype] = src_node_types_count[(src_node_types_unique == type_index).nonzero().item()] if type_index in src_node_types_unique else 0
-            self.dst_node_types_count_dict[ntype] = dst_node_types_count[(dst_node_types_unique == type_index).nonzero().item()] if type_index in dst_node_types_unique else 0
-
-        # Preparing dictionary for storing number of edges with given type in a graph
-        self.edge_types_count_dict = {}
-        for type_index, canonical_etype in enumerate(self.partition_book.canonical_etypes):
-            if type_index not in self.edge_types_count_dict:
-                self.edge_types_count_dict[canonical_etype] = 0
-            for shard_edge in all_shard_edges:
-                if self.src_is_tgt:
-                    # For not MFGs we need to convert seeds from local to global numbering
-                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds + self.tgt_node_range[0])
-                else:
-                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds)
-                self.edge_types_count_dict[canonical_etype] += (shard_edge.edge_features[dgl.ETYPE][edge_mask] == type_index).sum().item()
-
+        self._src_node_types_count_dict = {}
+        self._dst_node_types_count_dict = {}
+        self._edge_types_count_dict = {}
+        self._ntypes_nums_calculated = False
+        self._etypes_nums_calculated = False
+
         self._sampling_graph = None

+    @property
+    def src_node_types_count_dict(self) -> Dict:
+        if self._ntypes_nums_calculated is False:
+            self._calculate_ntypes()
+            self._ntypes_nums_calculated = True
+        return self._src_node_types_count_dict
+
+    @property
+    def dst_node_types_count_dict(self) -> Dict:
+        if self._ntypes_nums_calculated is False:
+            self._calculate_ntypes()
+            self._ntypes_nums_calculated = True
+        return self._dst_node_types_count_dict
+
+    @property
+    def edge_types_count_dict(self) -> Dict:
+        if self._etypes_nums_calculated is False:
+            self._calculate_etypes()
+            self._etypes_nums_calculated = True
+        return self._edge_types_count_dict
+
     @property
     def ntypes(self) -> List[str]:
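
The net effect of the hunk above: the counting work leaves __init__, and the three count dictionaries become read-only properties that compute their contents on first access and then return the cached result. A minimal standalone sketch of that lazy-initialization pattern, with illustrative names that are not the repository's:

    class TypeCountCache:
        """Compute a potentially expensive summary only when it is first requested."""

        def __init__(self, type_ids):
            self._type_ids = type_ids          # raw data kept around for later
            self._counts = {}                  # filled lazily
            self._counts_calculated = False    # guards the one-time computation

        @property
        def counts(self):
            # The first access pays the computation cost; later accesses reuse the cache.
            if not self._counts_calculated:
                self._calculate_counts()
                self._counts_calculated = True
            return self._counts

        def _calculate_counts(self):
            for type_id in self._type_ids:
                self._counts[type_id] = self._counts.get(type_id, 0) + 1

    cache = TypeCountCache([0, 0, 1, 2, 2, 2])
    print(cache.counts)   # {0: 2, 1: 1, 2: 3} -- computed on this call
    print(cache.counts)   # served from the cache, no recomputation

Note that the two node-type properties share one _ntypes_nums_calculated flag, so whichever of src_node_types_count_dict and dst_node_types_count_dict is read first fills both dictionaries in a single _calculate_ntypes pass.
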
@@ -636,9 +637,33 @@ def out_degrees(self, vertices=dgl.ALL, etype=None) -> Tensor:
             return self.out_degrees_cache[etype]

         return self.out_degrees_cache[etype][vertices]

+    def _calculate_ntypes(self):
+        src_node_types_unique, src_node_types_count = torch.unique(self.node_types[self.input_nodes], return_counts=True)
+        if self.src_is_tgt:
+            dst_node_types_unique, dst_node_types_count = torch.unique(self.node_types[self.seeds], return_counts=True)
+        else:
+            # For MFGs, we need to convert seeds from global to local numbering
+            dst_node_types_unique, dst_node_types_count = torch.unique(self.node_types[self.seeds - self.tgt_node_range[0]], return_counts=True)
+        for ntype in self.partition_book.ntypes:
+            type_index = self.partition_book.ntypes.index(ntype)
+            self._src_node_types_count_dict[ntype] = src_node_types_count[(src_node_types_unique == type_index).nonzero().item()] if type_index in src_node_types_unique else 0
+            self._dst_node_types_count_dict[ntype] = dst_node_types_count[(dst_node_types_unique == type_index).nonzero().item()] if type_index in dst_node_types_unique else 0
+
+    def _calculate_etypes(self):
+        for type_index, canonical_etype in enumerate(self.partition_book.canonical_etypes):
+            if canonical_etype not in self._edge_types_count_dict:
+                self._edge_types_count_dict[canonical_etype] = 0
+            for shard_edge in self.all_shard_edges:
+                if self.src_is_tgt:
+                    # For non-MFGs, we need to convert seeds from local to global numbering
+                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds + self.tgt_node_range[0])
+                else:
+                    edge_mask = torch.isin(shard_edge.edges[1], self.seeds)
+                self._edge_types_count_dict[canonical_etype] += (shard_edge.edge_features[dgl.ETYPE][edge_mask] == type_index).sum().item()
+
     def _get_active_tensors(self, message_func):
         message_params = ()
         if callable(message_func):
             arg_spec = inspect.getfullargspec(message_func)
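
_calculate_ntypes relies on torch.unique(..., return_counts=True), and _calculate_etypes on a torch.isin destination mask; below is a small self-contained illustration of both counting steps on toy tensors (all names and values are made up for the example; in the real method the edge type ids come from shard_edge.edge_features[dgl.ETYPE]):

    import torch

    # Node-type counting: one torch.unique call returns the type ids present
    # in the partition together with how often each occurs.
    node_types = torch.tensor([0, 0, 1, 2, 2, 2])      # type id of each local node
    unique_ids, counts = torch.unique(node_types, return_counts=True)
    ntype_names = ["user", "item", "tag"]               # list index == type id
    node_counts = {name: 0 for name in ntype_names}
    for type_id, count in zip(unique_ids.tolist(), counts.tolist()):
        node_counts[ntype_names[type_id]] = count
    print(node_counts)                                  # {'user': 2, 'item': 1, 'tag': 3}

    # Edge-type counting restricted to seed destinations: torch.isin builds a
    # boolean mask over edge destinations, and the surviving edges are counted per type id.
    edge_dst = torch.tensor([0, 1, 1, 3, 4])            # destination node of each edge
    edge_type_ids = torch.tensor([0, 0, 1, 1, 1])       # type id of each edge
    seeds = torch.tensor([1, 3])                        # destinations of interest
    edge_mask = torch.isin(edge_dst, seeds)
    for type_id, canonical_etype in enumerate([("user", "clicks", "item"),
                                               ("user", "tags", "item")]):
        print(canonical_etype, (edge_type_ids[edge_mask] == type_id).sum().item())
    # ('user', 'clicks', 'item') 1
    # ('user', 'tags', 'item') 2
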
