Skip to content

Commit

Permalink
implement async_op logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 9, 2024
1 parent b010268 commit ba02aa7
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions python/dgl/graphbolt/internal/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def unique_and_compact(
list are replaced with mapped node IDs, where each type of node is
mapped to a contiguous space of IDs ranging from 0 to N.
The unique nodes offsets tensor partitions the unique_nodes tensor. Has
size `world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]]
size `world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]`
belongs to the rank `(rank + i) % world_size`.
"""
is_heterogeneous = isinstance(nodes, dict)
Expand All @@ -83,19 +83,40 @@ def unique_and_compact(
else torch.ops.graphbolt.unique_and_compact_batched
)
results = unique_fn(concat_nodes, empties, empties, rank, world_size)
unique, compacted, offsets = {}, {}, {}
for ntype, result in zip(nodes.keys(), results):
(
unique[ntype],
concat_compacted,
_,
offsets[ntype],
) = result
compacted[ntype] = list(concat_compacted.split(nums[ntype]))
if is_heterogeneous:
return unique, compacted, offsets

class _Waiter:
    """Deferred post-processing handle for the unique-and-compact results.

    Stores the raw result handle together with the node type keys and the
    per-type split sizes, and turns them into the final return value when
    `wait` is called. `async_op`, `is_heterogeneous` and `homo_ntype` are
    not attributes of this class; they are read from the enclosing scope.
    """

    def __init__(self, future, ntypes, nums):
        self.future = future
        self.ntypes = ntypes
        self.nums = nums

    def wait(self):
        """Resolve the stored results and return the post-processed value.

        Returns `(unique, compacted, offsets)` as dictionaries keyed by
        node type when `is_heterogeneous` is true; otherwise returns the
        three entries stored under `homo_ntype`.

        The stored references are cleared during the call, so this object
        is not meant to be waited on a second time.
        """
        pending = self.future
        # In async mode the stored object has to be waited on; otherwise
        # it already is the results.
        raw_results = pending.wait() if async_op else pending
        node_types, split_sizes = self.ntypes, self.nums
        # Ensure there is no memory leak: release everything this object
        # holds now that the results are available.
        self.future = self.ntypes = self.nums = None

        unique, compacted, offsets = {}, {}, {}
        # Each result is a 4-item sequence; the third item is unused.
        for ntype, (uniq, concat_compacted, _, offs) in zip(
            node_types, raw_results
        ):
            unique[ntype] = uniq
            offsets[ntype] = offs
            # Presumably a tensor: split the concatenated compacted IDs
            # back into one piece per input using the recorded sizes.
            pieces = concat_compacted.split(split_sizes[ntype])
            compacted[ntype] = list(pieces)

        if not is_heterogeneous:
            return (
                unique[homo_ntype],
                compacted[homo_ntype],
                offsets[homo_ntype],
            )
        return unique, compacted, offsets

post_processer = _Waiter(results, nodes.keys(), nums)
if async_op:
return post_processer
else:
return unique[homo_ntype], compacted[homo_ntype], offsets[homo_ntype]
return post_processer.wait()


def compact_temporal_nodes(nodes, nodes_timestamp):
Expand Down Expand Up @@ -218,8 +239,8 @@ def unique_and_compact_csc_formats(
pairs are replaced with mapped node IDs, where each type of node is
mapped to a contiguous space of IDs ranging from 0 to N. The unique
nodes offsets tensor partitions the unique_nodes tensor. Has size
`world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]] belongs to
the rank `(rank + i) % world_size`.
`world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]` belongs
to the rank `(rank + i) % world_size`.
Examples
--------
Expand Down

0 comments on commit ba02aa7

Please sign in to comment.