From 99b70c45c3feecc2fbeec9a67e1c1cf572f6918b Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Fri, 15 Dec 2023 14:37:27 +0000 Subject: [PATCH 1/5] enable hetero distributed node-level and edge-level temporal sampling --- .../distributed/test_dist_neighbor_sampler.py | 207 +++++++++++++++++- .../distributed/dist_neighbor_sampler.py | 41 +++- torch_geometric/distributed/utils.py | 4 + torch_geometric/sampler/neighbor_sampler.py | 10 +- 4 files changed, 247 insertions(+), 15 deletions(-) diff --git a/test/distributed/test_dist_neighbor_sampler.py b/test/distributed/test_dist_neighbor_sampler.py index 5bc26ee95820..90796784034f 100644 --- a/test/distributed/test_dist_neighbor_sampler.py +++ b/test/distributed/test_dist_neighbor_sampler.py @@ -73,7 +73,8 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None): return (feature_store, graph_store), data -def create_hetero_data(tmp_path: str, rank: int): +def create_hetero_data(tmp_path: str, rank: int, + time_attr: Optional[str] = None): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) ( @@ -89,6 +90,27 @@ def create_hetero_data(tmp_path: str, rank: int): graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb graph_store.meta = feature_store.meta = meta + if time_attr == 'time': # Create node-level time data: + feature_store.put_tensor( + tensor=torch.full((len(out[3]['v0']), ), 1, dtype=torch.int64), + group_name='v0', + attr_name=time_attr, + ) + feature_store.put_tensor( + tensor=torch.full((len(out[3]['v1']), ), 2, dtype=torch.int64), + group_name='v1', + attr_name=time_attr, + ) + elif time_attr == 'edge_time': # Create edge-level time data: + for i, (attr, + edge_index) in enumerate(graph_store._edge_index.items()): + feature_store.put_tensor( + tensor=torch.full((edge_index.size(1), ), i, + dtype=torch.int64), + group_name=attr[0], + attr_name=time_attr, + ) + return feature_store, graph_store @@ -313,6 +335,87 @@ def dist_neighbor_sampler_hetero( ) +def dist_neighbor_sampler_temporal_hetero( + data: FakeHeteroDataset, + tmp_path: str, + world_size: int, + rank: int, + master_port: int, + input_type: str, + seed_time: torch.tensor = None, + temporal_strategy: str = 'uniform', + time_attr: str = 'time', +): + dist_data = create_hetero_data(tmp_path, rank, time_attr) + + current_ctx = DistContext( + rank=rank, + global_rank=rank, + world_size=world_size, + global_world_size=world_size, + group_name='dist-sampler-test', + ) + + dist_sampler = DistNeighborSampler( + data=dist_data, + current_ctx=current_ctx, + rpc_worker_names={}, + num_neighbors=[-1, -1], + shuffle=False, + disjoint=True, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + ) + + # Close RPC & worker group at exit: + atexit.register(close_sampler, 0, dist_sampler) + + init_rpc( + current_ctx=current_ctx, + rpc_worker_names={}, + master_addr='localhost', + master_port=master_port, + ) + + dist_sampler.register_sampler_rpc() + dist_sampler.init_event_loop() + + # Create inputs nodes such that each belongs to a different partition: + node_pb_list = dist_data[1].node_pb[input_type].tolist() + node_0 = node_pb_list.index(0) + node_1 = node_pb_list.index(1) + + input_node = torch.tensor([node_0, node_1], dtype=torch.int64) + + inputs = NodeSamplerInput( + input_id=None, + node=input_node, + time=seed_time, + input_type=input_type, + ) + + # Evaluate distributed node sample function: + out_dist = dist_sampler.event_loop.run_task( + coro=dist_sampler.node_sample(inputs)) + + sampler = NeighborSampler( + data=data, + num_neighbors=[-1, -1], + disjoint=True, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + ) + + # Evaluate node sample function: + out = node_sample(inputs, sampler._sample) + + # Compare distributed output with single machine output: + for k in data.node_types: + assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0]) + assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0]) + assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k] + + @onlyLinux @withPackage('pyg_lib') @pytest.mark.parametrize('disjoint', [False, True]) @@ -437,3 +540,105 @@ def test_dist_neighbor_sampler_hetero(tmp_path, disjoint): w1.start() w0.join() w1.join() + + +@withPackage('pyg_lib') +@pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]]) +@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) +def test_dist_neighbor_sampler_temporal_hetero(tmp_path, seed_time, + temporal_strategy): + if seed_time is not None: + seed_time = torch.tensor(seed_time) + + mp_context = torch.multiprocessing.get_context('spawn') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + + world_size = 2 + data = FakeHeteroDataset( + num_graphs=1, + avg_num_nodes=100, + avg_degree=3, + num_node_types=2, + num_edge_types=4, + edge_dim=2, + )[0] + + partitioner = Partitioner(data, world_size, tmp_path) + partitioner.generate_partition() + + # The partition generation script does not currently support temporal data. + # Therefore, it needs to be added after generating partitions. + data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1, + dtype=torch.int64) + data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2, + dtype=torch.int64) + + w0 = mp_context.Process( + target=dist_neighbor_sampler_temporal_hetero, + args=(data, tmp_path, world_size, 0, port, 'v0', seed_time, + temporal_strategy, 'time'), + ) + + w1 = mp_context.Process( + target=dist_neighbor_sampler_temporal_hetero, + args=(data, tmp_path, world_size, 1, port, 'v1', seed_time, + temporal_strategy, 'time'), + ) + + w0.start() + w1.start() + w0.join() + w1.join() + + +@withPackage('pyg_lib') +@pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]]) +@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) +def test_dist_neighbor_sampler_edge_level_temporal_hetero( + tmp_path, seed_time, temporal_strategy): + seed_time = torch.tensor(seed_time) + + mp_context = torch.multiprocessing.get_context('spawn') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + + world_size = 2 + data = FakeHeteroDataset( + num_graphs=1, + avg_num_nodes=100, + avg_degree=3, + num_node_types=2, + num_edge_types=4, + edge_dim=2, + )[0] + + partitioner = Partitioner(data, world_size, tmp_path) + partitioner.generate_partition() + + # The partition generation script does not currently support temporal data. + # Therefore, it needs to be added after generating partitions. + for i, edge_type in enumerate(data.edge_types): + data[edge_type].edge_time = torch.full( + (data[edge_type].edge_index.size(1), ), i, dtype=torch.int64) + + w0 = mp_context.Process( + target=dist_neighbor_sampler_temporal_hetero, + args=(data, tmp_path, world_size, 0, port, 'v0', seed_time, + temporal_strategy, 'edge_time'), + ) + + w1 = mp_context.Process( + target=dist_neighbor_sampler_temporal_hetero, + args=(data, tmp_path, world_size, 1, port, 'v1', seed_time, + temporal_strategy, 'edge_time'), + ) + + w0.start() + w1.start() + w0.join() + w1.join() diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index 04845bff3c0f..b8b34d10def5 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -228,16 +228,19 @@ async def node_sample( else: raise ValueError("Seed time needs to be specified") + # Heterogeneus Neighborhood Sampling ################################## + if self.is_hetero: if input_type is None: raise ValueError("Input type should be defined") - seed_dict: Dict[NodeType, Tensor] = {input_type: seed} - seed_time_dict: Dict[NodeType, Tensor] = {input_type: seed_time} - node_dict = NodeDict(self.node_types, self.num_hops) batch_dict = BatchDict(self.node_types, self.num_hops) + seed_dict: Dict[NodeType, Tensor] = {input_type: seed} + node_dict.seed_time[input_type][0] = seed_time.clone( + ) if self.temporal else None + edge_dict: Dict[EdgeType, Tensor] = { k: torch.empty(0, dtype=torch.int64) for k in self.edge_types @@ -281,8 +284,6 @@ async def node_sample( num_sampled_edges_dict[edge_type].append(0) continue - seed_time = (seed_time_dict or {}).get(src, None) - if isinstance(self.num_neighbors, list): one_hop_num = self.num_neighbors[i] else: @@ -292,7 +293,7 @@ async def node_sample( out = await self.sample_one_hop( node_dict.src[src][i], one_hop_num, - seed_time, + node_dict.seed_time[src][i], batch_dict.src[src][i], edge_type, ) @@ -306,9 +307,9 @@ async def node_sample( # Remove duplicates: ( - node_src, + src_node, node_dict.out[dst], - batch_src, + src_batch, batch_dict.out[dst], ) = remove_duplicates( out, @@ -319,10 +320,10 @@ async def node_sample( # Create src nodes for the next layer: node_dict.src[dst][i + 1] = torch.cat( - [node_dict.src[dst][i + 1], node_src]) + [node_dict.src[dst][i + 1], src_node]) if self.disjoint: batch_dict.src[dst][i + 1] = torch.cat( - [batch_dict.src[dst][i + 1], batch_src]) + [batch_dict.src[dst][i + 1], src_batch]) # Save sampled nodes with duplicates to be able to create # local edge indices: @@ -336,6 +337,18 @@ async def node_sample( batch_dict.with_dupl[dst] = torch.cat( [batch_dict.with_dupl[dst], out.batch]) + if self.temporal and i < self.num_hops - 1: + # Assign seed time based on source nodes subgraph IDs. + src_seed_time = [ + seed_time[(seed_batch == batch_idx).nonzero()] + for batch_idx in src_batch + ] + src_seed_time = torch.as_tensor( + src_seed_time, dtype=torch.int64) + + node_dict.seed_time[dst][i + 1] = torch.cat( + [node_dict.seed_time[dst][i + 1], src_seed_time]) + # Collect sampled neighbors per node for each layer: sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0] @@ -371,6 +384,9 @@ async def node_sample( num_sampled_edges=num_sampled_edges_dict, metadata=metadata, ) + + # Homogeneus Neighborhood Sampling #################################### + else: src = seed node = src.clone() @@ -797,8 +813,9 @@ def _sample_one_hop( rel_type = '__'.join(edge_type) colptr = self._sampler.colptr_dict[rel_type] row = self._sampler.row_dict[rel_type] - node_time = (self.node_time or {}).get(edge_type[2], None) - edge_time = (self.edge_time or {}).get(edge_type[2], None) + # `node_time` is a destination node time: + node_time = (self.node_time or {}).get(edge_type[0], None) + edge_time = (self.edge_time or {}).get(edge_type, None) out = torch.ops.pyg.dist_neighbor_sample( colptr, diff --git a/torch_geometric/distributed/utils.py b/torch_geometric/distributed/utils.py index 6b9f8196e6ea..295785a8652c 100644 --- a/torch_geometric/distributed/utils.py +++ b/torch_geometric/distributed/utils.py @@ -29,6 +29,10 @@ def __init__(self, node_types, num_hops): k: torch.empty(0, dtype=torch.int64) for k in node_types } + self.seed_time: Dict[NodeType, List[Tensor]] = { + k: num_hops * [torch.empty(0, dtype=torch.int64)] + for k in node_types + } class BatchDict: diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 33b694432a1e..40779196fe9b 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -234,16 +234,22 @@ def __init__( self.node_time: Optional[Dict[NodeType, Tensor]] = None self.edge_time: Optional[Dict[NodeType, Tensor]] = None - # TODO Add support for edge-level temporal sampling. if time_attr is not None: for attr in time_attrs: # Reset index for full data. attr.index = None time_tensors = feature_store.multi_get_tensor(time_attrs) - self.node_time = { + + # Currently, we determine whether to use node-level or + # edge-level temporal sampling based on the attribute name. + time = { time_attr.group_name: time_tensor for time_attr, time_tensor in zip( time_attrs, time_tensors) } + if time_attr == 'time': + self.node_time = time + else: + self.edge_time = time # Conversion to/from C++ string type (see above): self.to_rel_type = {k: '__'.join(k) for k in self.edge_types} From 30607896532002547ba471cbd376494946704885 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Fri, 15 Dec 2023 16:36:48 +0000 Subject: [PATCH 2/5] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3e838ec0761..f16653904ef5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399)) - Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369)) - Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375)) +- Update `DistNeighborSampler` for heterogeneous graphs ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624)) - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083)) - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210)) - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220)) From ff919a818777c12b160d8f6eaa5bb23f13152138 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Thu, 21 Dec 2023 10:34:20 +0000 Subject: [PATCH 3/5] `init_event_loop` return None --- torch_geometric/distributed/dist_neighbor_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index b8b34d10def5..744c87baaccd 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -126,7 +126,7 @@ def register_sampler_rpc(self) -> None: rpc_sample_callee = RPCSamplingCallee(self) self.rpc_sample_callee_id = rpc_register(rpc_sample_callee) - def init_event_loop(self): + def init_event_loop(self) -> None: if self.event_loop is None: self.event_loop = ConcurrentEventLoop(self.concurrency) self.event_loop.start_loop() From af46c1165c0d2eb55dafaaf7cf67a7e61f41d28c Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Thu, 21 Dec 2023 10:53:22 +0000 Subject: [PATCH 4/5] test update --- test/distributed/test_dist_neighbor_sampler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_dist_neighbor_sampler.py b/test/distributed/test_dist_neighbor_sampler.py index 90796784034f..d865c15e23c2 100644 --- a/test/distributed/test_dist_neighbor_sampler.py +++ b/test/distributed/test_dist_neighbor_sampler.py @@ -92,12 +92,12 @@ def create_hetero_data(tmp_path: str, rank: int, if time_attr == 'time': # Create node-level time data: feature_store.put_tensor( - tensor=torch.full((len(out[3]['v0']), ), 1, dtype=torch.int64), + tensor=torch.full((len(node_pb['v0']), ), 1, dtype=torch.int64), group_name='v0', attr_name=time_attr, ) feature_store.put_tensor( - tensor=torch.full((len(out[3]['v1']), ), 2, dtype=torch.int64), + tensor=torch.full((len(node_pb['v1']), ), 2, dtype=torch.int64), group_name='v1', attr_name=time_attr, ) @@ -368,7 +368,7 @@ def dist_neighbor_sampler_temporal_hetero( ) # Close RPC & worker group at exit: - atexit.register(close_sampler, 0, dist_sampler) + atexit.register(shutdown_rpc) init_rpc( current_ctx=current_ctx, @@ -377,8 +377,10 @@ def dist_neighbor_sampler_temporal_hetero( master_port=master_port, ) + dist_sampler.init_sampler_instance() dist_sampler.register_sampler_rpc() - dist_sampler.init_event_loop() + dist_sampler.event_loop = ConcurrentEventLoop(2) + dist_sampler.event_loop.start_loop() # Create inputs nodes such that each belongs to a different partition: node_pb_list = dist_data[1].node_pb[input_type].tolist() From 215a0d99cd1156c875c059750c089ecaa97d3202 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 30 Dec 2023 19:20:41 +0000 Subject: [PATCH 5/5] update --- CHANGELOG.md | 3 +- .../distributed/test_dist_neighbor_sampler.py | 28 ++++++++++++------ .../distributed/dist_neighbor_sampler.py | 10 +++---- torch_geometric/sampler/neighbor_sampler.py | 29 +++++++++++-------- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c94549063c07..23a89a7b84d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519)) - Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399)) - Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369)) -- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375)) -- Update `DistNeighborSampler` for heterogeneous graphs ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624)) +- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624)) - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083)) - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210)) - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220)) diff --git a/test/distributed/test_dist_neighbor_sampler.py b/test/distributed/test_dist_neighbor_sampler.py index d865c15e23c2..778c8999c1c8 100644 --- a/test/distributed/test_dist_neighbor_sampler.py +++ b/test/distributed/test_dist_neighbor_sampler.py @@ -73,8 +73,11 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None): return (feature_store, graph_store), data -def create_hetero_data(tmp_path: str, rank: int, - time_attr: Optional[str] = None): +def create_hetero_data( + tmp_path: str, + rank: int, + time_attr: Optional[str] = None, +): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) ( @@ -102,14 +105,15 @@ def create_hetero_data(tmp_path: str, rank: int, attr_name=time_attr, ) elif time_attr == 'edge_time': # Create edge-level time data: - for i, (attr, - edge_index) in enumerate(graph_store._edge_index.items()): + i = 0 + for attr, edge_index in graph_store._edge_index.items(): + time = torch.full((edge_index.size(1), ), i, dtype=torch.int64) feature_store.put_tensor( - tensor=torch.full((edge_index.size(1), ), i, - dtype=torch.int64), + tensor=time, group_name=attr[0], attr_name=time_attr, ) + i += 1 return feature_store, graph_store @@ -547,8 +551,11 @@ def test_dist_neighbor_sampler_hetero(tmp_path, disjoint): @withPackage('pyg_lib') @pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) -def test_dist_neighbor_sampler_temporal_hetero(tmp_path, seed_time, - temporal_strategy): +def test_dist_neighbor_sampler_temporal_hetero( + tmp_path, + seed_time, + temporal_strategy, +): if seed_time is not None: seed_time = torch.tensor(seed_time) @@ -600,7 +607,10 @@ def test_dist_neighbor_sampler_temporal_hetero(tmp_path, seed_time, @pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]]) @pytest.mark.parametrize('temporal_strategy', ['uniform', 'last']) def test_dist_neighbor_sampler_edge_level_temporal_hetero( - tmp_path, seed_time, temporal_strategy): + tmp_path, + seed_time, + temporal_strategy, +): seed_time = torch.tensor(seed_time) mp_context = torch.multiprocessing.get_context('spawn') diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index 744c87baaccd..0d61f85230c9 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -228,7 +228,7 @@ async def node_sample( else: raise ValueError("Seed time needs to be specified") - # Heterogeneus Neighborhood Sampling ################################## + # Heterogeneous Neighborhood Sampling ################################# if self.is_hetero: if input_type is None: @@ -238,8 +238,8 @@ async def node_sample( batch_dict = BatchDict(self.node_types, self.num_hops) seed_dict: Dict[NodeType, Tensor] = {input_type: seed} - node_dict.seed_time[input_type][0] = seed_time.clone( - ) if self.temporal else None + if self.temporal: + node_dict.seed_time[input_type][0] = seed_time.clone() edge_dict: Dict[EdgeType, Tensor] = { k: torch.empty(0, dtype=torch.int64) @@ -338,7 +338,7 @@ async def node_sample( [batch_dict.with_dupl[dst], out.batch]) if self.temporal and i < self.num_hops - 1: - # Assign seed time based on source nodes subgraph IDs. + # Assign seed time based on source node subgraph ID: src_seed_time = [ seed_time[(seed_batch == batch_idx).nonzero()] for batch_idx in src_batch @@ -385,7 +385,7 @@ async def node_sample( metadata=metadata, ) - # Homogeneus Neighborhood Sampling #################################### + # Homogeneous Neighborhood Sampling ################################### else: src = seed diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 40779196fe9b..d77ecd4741a3 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -222,9 +222,11 @@ def __init__( self.row, self.colptr, self.perm = graph_store.csc() else: - self.node_types = list( - set(attr.group_name for attr in attrs - if isinstance(attr.group_name, NodeType))) + node_types = [ + attr.group_name for attr in attrs + if isinstance(attr.group_name, str) + ] + self.node_types = list(set(node_types)) self.num_nodes = { node_type: remote_backend_utils.size(*data, node_type) for node_type in self.node_types @@ -232,24 +234,27 @@ def __init__( self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None self.node_time: Optional[Dict[NodeType, Tensor]] = None - self.edge_time: Optional[Dict[NodeType, Tensor]] = None + self.edge_time: Optional[Dict[EdgeType, Tensor]] = None if time_attr is not None: for attr in time_attrs: # Reset index for full data. attr.index = None - time_tensors = feature_store.multi_get_tensor(time_attrs) - # Currently, we determine whether to use node-level or - # edge-level temporal sampling based on the attribute name. + time_tensors = feature_store.multi_get_tensor(time_attrs) time = { - time_attr.group_name: time_tensor - for time_attr, time_tensor in zip( - time_attrs, time_tensors) + attr.group_name: time_tensor + for attr, time_tensor in zip(time_attrs, time_tensors) } - if time_attr == 'time': + + group_names = [attr.group_name for attr in time_attrs] + if all([isinstance(g, str) for g in group_names]): self.node_time = time - else: + elif all([isinstance(g, tuple) for g in group_names]): self.edge_time = time + else: + raise ValueError( + f"Found time attribute '{time_attr}' for both " + f"node-level and edge-level types") # Conversion to/from C++ string type (see above): self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}