Skip to content

Commit

Permalink
Data.update() functionality (#6313)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Dec 29, 2022
1 parent 77d38cb commit aa42868
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313))
- Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))
- Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))
- Added `AttentionExplainer` ([#6279](https://github.com/pyg-team/pytorch_geometric/pull/6279))
Expand Down
11 changes: 11 additions & 0 deletions test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,17 @@ def my_attr1(self, value):
assert data.my_attr1 == 2


def test_data_update():
data = Data(x=torch.arange(0, 5), y=torch.arange(5, 10))
other = Data(z=torch.arange(10, 15), x=torch.arange(15, 20))
data.update(other)

assert len(data) == 3
assert torch.equal(data.x, torch.arange(15, 20))
assert torch.equal(data.y, torch.arange(5, 10))
assert torch.equal(data.z, torch.arange(10, 15))


# Feature Store ###############################################################


Expand Down
21 changes: 21 additions & 0 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,27 @@ def test_hetero_data_invalid_names():
assert data.edge_types == [('my test', 'a__b', 'my test')]


def test_hetero_data_update():
data = HeteroData()
data['paper'].x = torch.arange(0, 5)
data['paper'].y = torch.arange(5, 10)
data['author'].x = torch.arange(10, 15)

other = HeteroData()
other['paper'].x = torch.arange(15, 20)
other['author'].y = torch.arange(20, 25)
other['paper', 'paper'].edge_index = torch.randint(5, (2, 20))

data.update(other)
assert len(data) == 3
assert torch.equal(data['paper'].x, torch.arange(15, 20))
assert torch.equal(data['paper'].y, torch.arange(5, 10))
assert torch.equal(data['author'].x, torch.arange(10, 15))
assert torch.equal(data['author'].y, torch.arange(20, 25))
assert torch.equal(data['paper', 'paper'].edge_index,
other['paper', 'paper'].edge_index)


# Feature Store ###############################################################


Expand Down
10 changes: 10 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def to_namedtuple(self) -> NamedTuple:
r"""Returns a :obj:`NamedTuple` of stored key/value pairs."""
raise NotImplementedError

def update(self, data: 'BaseData') -> 'BaseData':
r"""Updates the data object with the elements from another data object.
"""
raise NotImplementedError

def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
r"""Returns the dimension for which the value :obj:`value` of the
attribute :obj:`key` will get concatenated when creating mini-batches
Expand Down Expand Up @@ -507,6 +512,11 @@ def to_dict(self) -> Dict[str, Any]:
def to_namedtuple(self) -> NamedTuple:
return self._store.to_namedtuple()

def update(self, data: 'Data') -> 'Data':
for key, value in data.items():
self[key] = value
return self

def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
if isinstance(value, SparseTensor) and 'adj' in key:
return (0, 1)
Expand Down
6 changes: 6 additions & 0 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def to_namedtuple(self) -> NamedTuple:
DataTuple = namedtuple('DataTuple', field_names)
return DataTuple(*field_values)

def update(self, data: 'HeteroData') -> 'HeteroData':
for store in data.stores:
for key, value in store.items():
self[store._key][key] = value
return self

def __cat_dim__(self, key: str, value: Any,
store: Optional[NodeOrEdgeStorage] = None, *args,
**kwargs) -> Any:
Expand Down

0 comments on commit aa42868

Please sign in to comment.