From 36debdf47a9d6aef76f4a5c5d49ff12b6aa75c69 Mon Sep 17 00:00:00 2001 From: Zecheng Zhang Date: Wed, 15 Feb 2023 13:54:58 -0800 Subject: [PATCH 1/2] Improve code cov for graph store --- test/data/test_graph_store.py | 21 +++++++++++++++++++++ torch_geometric/data/graph_store.py | 5 +++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/test/data/test_graph_store.py b/test/data/test_graph_store.py index 1b2465b6c206..98d635225deb 100644 --- a/test/data/test_graph_store.py +++ b/test/data/test_graph_store.py @@ -2,6 +2,7 @@ import torch from torch_sparse import SparseTensor +from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout from torch_geometric.testing import MyGraphStore @@ -14,6 +15,8 @@ def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): def test_graph_store(): graph_store = MyGraphStore() + assert str(graph_store) == 'MyGraphStore()' + coo = torch.tensor([0, 1]), torch.tensor([1, 2]) csr = torch.tensor([0, 1, 2]), torch.tensor([1, 2]) csc = torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2]) @@ -31,6 +34,10 @@ def test_graph_store(): assert len(graph_store.get_all_edge_attrs()) == 3 + del graph_store['edge_type', 'coo'] + with pytest.raises(KeyError): + graph_store['edge_type', 'coo'] + with pytest.raises(KeyError): graph_store['edge_type_2', 'coo'] @@ -73,3 +80,17 @@ def test_graph_store_conversion(): out = graph_store.coo([('v', '1', 'v')]) assert torch.equal(list(out[0].values())[0], coo[0]) assert torch.equal(list(out[1].values())[0], coo[1]) + + # Ensure that 'store' parameter works as intended: + key = EdgeAttr(edge_type=('v', '1', 'v'), layout=EdgeLayout.CSR, + is_sorted=False, size=(100, 100)) + with pytest.raises(KeyError): + graph_store[key] + + out = graph_store.csr([('v', '1', 'v')], store=True) + assert torch.equal(list(out[0].values())[0], csr[0]) + assert torch.equal(list(out[1].values())[0].sort()[0], csr[1].sort()[0]) + + out = graph_store[key] + assert torch.equal(out[0], csr[0]) + assert torch.equal(out[1].sort()[0], csr[1].sort()[0]) diff --git a/torch_geometric/data/graph_store.py b/torch_geometric/data/graph_store.py index d8efefa03dc3..8ad626de6d9d 100644 --- a/torch_geometric/data/graph_store.py +++ b/torch_geometric/data/graph_store.py @@ -126,8 +126,9 @@ def put_edge_index(self, edge_index: EdgeTensorType, *args, Returns whether insertion was successful. Args: - tensor (Tuple[torch.Tensor, torch.Tensor]): The :obj:`edge_index` - tuple in a format specified in :class:`EdgeAttr`. + edge_index (Tuple[torch.Tensor, torch.Tensor]): The + :obj:`edge_index` tuple in a format specified in + :class:`EdgeAttr`. **kwargs (EdgeAttr): Any relevant edge attributes that correspond to the :obj:`edge_index` tuple. See the :class:`EdgeAttr` documentation for required and optional From feb347ad4b8777e5dc592881918ee75319bca14d Mon Sep 17 00:00:00 2001 From: Zecheng Zhang Date: Wed, 15 Feb 2023 14:06:32 -0800 Subject: [PATCH 2/2] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 175f2ad00252..e7255ea9ace7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685)) - Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)) - Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609)) -- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703)) +- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720))) - Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522)) - Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517)) - Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))