From ba02aa7e2372d81c466d3bfae8970634ceb26100 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Mon, 9 Sep 2024 15:26:06 +0000
Subject: [PATCH] implement `async_op` logic.

---
 python/dgl/graphbolt/internal/sample_utils.py | 51 +++++++++++++------
 1 file changed, 36 insertions(+), 15 deletions(-)

diff --git a/python/dgl/graphbolt/internal/sample_utils.py b/python/dgl/graphbolt/internal/sample_utils.py
index 6d86d5d03063..2866e4ed379b 100644
--- a/python/dgl/graphbolt/internal/sample_utils.py
+++ b/python/dgl/graphbolt/internal/sample_utils.py
@@ -62,7 +62,7 @@ def unique_and_compact(
         list are replaced with mapped node IDs, where each type of node is
         mapped to a contiguous space of IDs ranging from 0 to N.
         The unique nodes offsets tensor partitions the unique_nodes tensor. Has
-        size `world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]]
+        size `world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]`
         belongs to the rank `(rank + i) % world_size`.
     """
     is_heterogeneous = isinstance(nodes, dict)
@@ -83,19 +83,40 @@ def unique_and_compact(
         else torch.ops.graphbolt.unique_and_compact_batched
     )
     results = unique_fn(concat_nodes, empties, empties, rank, world_size)
-    unique, compacted, offsets = {}, {}, {}
-    for ntype, result in zip(nodes.keys(), results):
-        (
-            unique[ntype],
-            concat_compacted,
-            _,
-            offsets[ntype],
-        ) = result
-        compacted[ntype] = list(concat_compacted.split(nums[ntype]))
-    if is_heterogeneous:
-        return unique, compacted, offsets
+
+    class _Waiter:
+        def __init__(self, future, ntypes, nums):
+            self.future = future
+            self.ntypes = ntypes
+            self.nums = nums
+
+        def wait(self):
+            """Returns the stored value when invoked."""
+            results = self.future.wait() if async_op else self.future
+            ntypes = self.ntypes
+            nums = self.nums
+            # Ensure there is no memory leak.
+            self.future = self.ntypes = self.nums = None
+
+            unique, compacted, offsets = {}, {}, {}
+            for ntype, result in zip(ntypes, results):
+                (
+                    unique[ntype],
+                    concat_compacted,
+                    _,
+                    offsets[ntype],
+                ) = result
+                compacted[ntype] = list(concat_compacted.split(nums[ntype]))
+            if is_heterogeneous:
+                return unique, compacted, offsets
+            else:
+                return unique[homo_ntype], compacted[homo_ntype], offsets[homo_ntype]
+
+    post_processer = _Waiter(results, nodes.keys(), nums)
+    if async_op:
+        return post_processer
     else:
-        return unique[homo_ntype], compacted[homo_ntype], offsets[homo_ntype]
+        return post_processer.wait()
 
 
 def compact_temporal_nodes(nodes, nodes_timestamp):
@@ -218,8 +239,8 @@ def unique_and_compact_csc_formats(
         pairs are replaced with mapped node IDs, where each type of node is
         mapped to a contiguous space of IDs ranging from 0 to N.
         The unique nodes offsets tensor partitions the unique_nodes tensor. Has size
-        `world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]] belongs to
-        the rank `(rank + i) % world_size`.
+        `world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]` belongs
+        to the rank `(rank + i) % world_size`.
 
     Examples
     --------