From aa42868684fca0cf7cb1433a3da799a64e845ac6 Mon Sep 17 00:00:00 2001
From: Matthias Fey
Date: Thu, 29 Dec 2022 16:07:17 +0100
Subject: [PATCH] `Data.update()` functionality (#6313)

---
 CHANGELOG.md                        |  1 +
 test/data/test_data.py              | 11 +++++++++++
 test/data/test_hetero_data.py       | 21 +++++++++++++++++++++
 torch_geometric/data/data.py        | 10 ++++++++++
 torch_geometric/data/hetero_data.py |  6 ++++++
 5 files changed, 49 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f26bf8af091..377ad94d77ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.3.0] - 2023-MM-DD
 ### Added
+- Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313))
 - Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))
 - Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))
 - Added `AttentionExplainer` ([#6279](https://github.com/pyg-team/pytorch_geometric/pull/6279))
diff --git a/test/data/test_data.py b/test/data/test_data.py
index 98e1bf5f5af4..b08b606e33df 100644
--- a/test/data/test_data.py
+++ b/test/data/test_data.py
@@ -288,6 +288,17 @@ def my_attr1(self, value):
     assert data.my_attr1 == 2
 
 
+def test_data_update():
+    data = Data(x=torch.arange(0, 5), y=torch.arange(5, 10))
+    other = Data(z=torch.arange(10, 15), x=torch.arange(15, 20))
+    data.update(other)
+
+    assert len(data) == 3
+    assert torch.equal(data.x, torch.arange(15, 20))
+    assert torch.equal(data.y, torch.arange(5, 10))
+    assert torch.equal(data.z, torch.arange(10, 15))
+
+
 # Feature Store ###############################################################
 
 
diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py
index ee63bd9361ca..01f094a43a58 100644
--- a/test/data/test_hetero_data.py
+++ b/test/data/test_hetero_data.py
@@ -452,6 +452,27 @@ def test_hetero_data_invalid_names():
     assert data.edge_types == [('my test', 'a__b', 'my test')]
 
 
+def test_hetero_data_update():
+    data = HeteroData()
+    data['paper'].x = torch.arange(0, 5)
+    data['paper'].y = torch.arange(5, 10)
+    data['author'].x = torch.arange(10, 15)
+
+    other = HeteroData()
+    other['paper'].x = torch.arange(15, 20)
+    other['author'].y = torch.arange(20, 25)
+    other['paper', 'paper'].edge_index = torch.randint(5, (2, 20))
+
+    data.update(other)
+    assert len(data) == 3
+    assert torch.equal(data['paper'].x, torch.arange(15, 20))
+    assert torch.equal(data['paper'].y, torch.arange(5, 10))
+    assert torch.equal(data['author'].x, torch.arange(10, 15))
+    assert torch.equal(data['author'].y, torch.arange(20, 25))
+    assert torch.equal(data['paper', 'paper'].edge_index,
+                       other['paper', 'paper'].edge_index)
+
+
 # Feature Store ###############################################################
 
 
diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py
index 7249c5adc103..50ae260fd26c 100644
--- a/torch_geometric/data/data.py
+++ b/torch_geometric/data/data.py
@@ -91,6 +91,11 @@ def to_namedtuple(self) -> NamedTuple:
         r"""Returns a :obj:`NamedTuple` of stored key/value pairs."""
         raise NotImplementedError
 
+    def update(self, data: 'BaseData') -> 'BaseData':
+        r"""Updates the data object with the elements from another data object.
+        """
+        raise NotImplementedError
+
     def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
         r"""Returns the dimension for which the value :obj:`value` of the
         attribute :obj:`key` will get concatenated when creating mini-batches
@@ -507,6 +512,11 @@ def to_dict(self) -> Dict[str, Any]:
     def to_namedtuple(self) -> NamedTuple:
         return self._store.to_namedtuple()
 
+    def update(self, data: 'Data') -> 'Data':
+        for key, value in data.items():
+            self[key] = value
+        return self
+
     def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
         if isinstance(value, SparseTensor) and 'adj' in key:
             return (0, 1)
diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py
index 016ab652f843..46aa5ad4194a 100644
--- a/torch_geometric/data/hetero_data.py
+++ b/torch_geometric/data/hetero_data.py
@@ -274,6 +274,12 @@ def to_namedtuple(self) -> NamedTuple:
         DataTuple = namedtuple('DataTuple', field_names)
         return DataTuple(*field_values)
 
+    def update(self, data: 'HeteroData') -> 'HeteroData':
+        for store in data.stores:
+            for key, value in store.items():
+                self[store._key][key] = value
+        return self
+
     def __cat_dim__(self, key: str, value: Any,
                     store: Optional[NodeOrEdgeStorage] = None, *args,
                     **kwargs) -> Any: