Add batch and ptr vectors for a list of tensors and nested dicts #4837

Merged: 17 commits into pyg-team:master on Jun 30, 2022

Conversation

@jan-meissner (Contributor) commented Jun 21, 2022

Feature:

Adds batching for lists of tensors and nested dicts. The motivation comes from higher-order GNNs such as Simplicial Neural Networks or networks on CW complexes, where one often needs to batch over an unknown number of tensors. Also adds support for calculating the batch vector for nested dicts.

Example:

Batching over a dataset with data objects of the form Data(xs=[tensor1, tensor2, tensor3]) and using follow_batch=['xs'] in the loader now yields DataBatch(xs=[3], xs_batch=[3], xs_ptr=[3]), where xs_batch is the list of batch vectors [tensor1_batch, tensor2_batch, tensor3_batch]. The same holds analogously for ptr.
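A minimal runnable sketch of the new behavior (tensor shapes are chosen arbitrarily for illustration; the printed values follow from the semantics described above):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Two toy data objects whose `xs` attribute is a list of three tensors.
data_list = [
    Data(xs=[torch.rand(4, 3), torch.rand(11, 4), torch.rand(1, 2)]),
    Data(xs=[torch.rand(5, 3), torch.rand(14, 4), torch.rand(3, 2)]),
]

loader = DataLoader(data_list, batch_size=2, follow_batch=['xs'])
batch = next(iter(loader))

# `xs` is now a list of three element-wise concatenated tensors, and
# `xs_batch` / `xs_ptr` are lists of per-element batch and ptr vectors.
print(len(batch.xs))      # 3
print(batch.xs[0].shape)  # torch.Size([9, 3])  (4 + 5 rows)
print(batch.xs_batch[0])  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 1])
print(batch.xs_ptr[0])    # tensor([0, 4, 9])
```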

Additional issue fixed:

Resolved an issue concerning separate (and therefore InMemoryDataset) when batching over lists of tensors: InMemoryDataset failed to recover the data_list correctly. Here is a code snippet to reproduce the issue:

import torch
from torch_geometric.data import Data, InMemoryDataset


class MyTestDataset2(InMemoryDataset):
    def __init__(self, data_list):
        super().__init__('/tmp/MyTestDataset2')
        self.data, self.slices = self.collate(data_list)


def test_lists_of_tensors_in_memory_dataset():
    def tr(n, m):  # shorthand for a random (n, m) tensor
        return torch.rand((n, m))

    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)])
    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)])
    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)])
    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)])

    data_list = [d1, d2, d3, d4]

    dataset = MyTestDataset2(data_list)
    assert isinstance(dataset[0].xs, list)
    assert len(dataset) == 4
    assert dataset[0].xs[1].shape == (11, 4)
    assert dataset[0].xs[2].shape == (1, 2)
    assert dataset[1].xs[0].shape == (5, 3)
    assert dataset[2].xs[1].shape == (15, 4)
    assert dataset[3].xs[1].shape == (16, 4)
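Without the fix, InMemoryDataset infers len() incorrectly for data containing lists of tensors, so recovering individual data objects fails the shape assertions above; see the nested_iter discussion below.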

codecov bot commented Jun 21, 2022

Codecov Report

Merging #4837 (e495331) into master (7f55f41) will decrease coverage by 1.73%.
The diff coverage is 96.87%.

❗ Current head e495331 differs from pull request most recent head 2f2f935. Consider uploading reports for the commit 2f2f935 to get more accurate results

@@            Coverage Diff             @@
##           master    #4837      +/-   ##
==========================================
- Coverage   84.52%   82.79%   -1.74%     
==========================================
  Files         329      329              
  Lines       17827    17732      -95     
==========================================
- Hits        15069    14681     -388     
- Misses       2758     3051     +293     
| Impacted Files | Coverage Δ |
| --- | --- |
| torch_geometric/data/collate.py | 96.40% <95.45%> (-0.30%) ⬇️ |
| torch_geometric/data/in_memory_dataset.py | 81.15% <100.00%> (+3.88%) ⬆️ |
| torch_geometric/data/separate.py | 100.00% <100.00%> (+2.22%) ⬆️ |
| torch_geometric/nn/models/dimenet_utils.py | 0.00% <0.00%> (-75.52%) ⬇️ |
| torch_geometric/nn/models/dimenet.py | 14.51% <0.00%> (-53.00%) ⬇️ |
| torch_geometric/nn/conv/utils/typing.py | 81.25% <0.00%> (-17.50%) ⬇️ |
| torch_geometric/nn/inits.py | 67.85% <0.00%> (-7.15%) ⬇️ |
| torch_geometric/nn/resolver.py | 86.04% <0.00%> (-6.98%) ⬇️ |
| torch_geometric/io/tu.py | 93.90% <0.00%> (-2.44%) ⬇️ |
| torch_geometric/nn/models/mlp.py | 98.52% <0.00%> (-1.48%) ⬇️ |

... and 15 more

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3d6eb74...2f2f935.

@jan-meissner jan-meissner marked this pull request as draft June 21, 2022 12:53
…of tensors. Modified separate such that InMemoryDataset can deal with lists of tensors correctly. Added a test for it.
@jan-meissner (Contributor, Author) commented Jun 21, 2022

Modified separate to deal correctly with lists of tensors. Added a test case for batching lists of SparseTensors; this now covers the special case in line 89 of separate.py, which was not covered before.

Fixed InMemoryDataset inferring len() incorrectly for data with lists of tensors by changing nested_iter in InMemoryDataset. This works as long as the slices in the slice_dict (or slice tree, as it can also contain lists) are Tensors, which is currently guaranteed by _collate. Added a test for it.
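For illustration, a rough sketch of such a recursive traversal, assuming (as _collate guarantees) that every leaf of the slice tree is a 1D Tensor of slice boundaries; this is not the exact code from the PR:

```python
def nested_iter(node):
    # Walk a slice tree that may nest dicts and lists, yielding its
    # Tensor leaves; _collate guarantees every leaf is a Tensor.
    if isinstance(node, dict):
        for value in node.values():
            yield from nested_iter(value)
    elif isinstance(node, list):
        for value in node:
            yield from nested_iter(value)
    else:
        yield node  # a 1D Tensor of slice boundaries


def infer_len(slices) -> int:
    # Any leaf holds num_examples + 1 slice boundaries, so a single
    # leaf is enough to determine the dataset length.
    for leaf in nested_iter(slices):
        return leaf.numel() - 1
    return 0
```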

@jan-meissner jan-meissner marked this pull request as ready for review June 21, 2022 14:22
@jan-meissner jan-meissner marked this pull request as draft June 22, 2022 15:29
@jan-meissner jan-meissner marked this pull request as ready for review June 22, 2022 16:59
@jan-meissner jan-meissner changed the title Added batch vector for lists of tensors and other recursive data structures. Add batch and ptr vectors for a list of tensors and nested dicts. Jun 24, 2022
@rusty1s (Member) left a comment
Thank you! This looks pretty clean. If possible, I would like to move computation of batch and ptr into a single function though.
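In that spirit, a hypothetical helper (name and signature are illustrative, not the PR's actual code) that derives both vectors from the per-example sizes in one pass:

```python
import torch
from torch import Tensor
from typing import Tuple


def batch_and_ptr(sizes: Tensor) -> Tuple[Tensor, Tensor]:
    # `sizes[i]` is the number of rows example i contributes; `ptr`
    # holds cumulative offsets and `batch` maps each row to its example.
    ptr = torch.cat([sizes.new_zeros(1), sizes.cumsum(dim=0)])
    batch = torch.repeat_interleave(
        torch.arange(len(sizes), device=sizes.device), sizes)
    return batch, ptr


batch, ptr = batch_and_ptr(torch.tensor([4, 5]))
# batch == tensor([0, 0, 0, 0, 1, 1, 1, 1, 1]); ptr == tensor([0, 4, 9])
```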

Resolved review threads (now outdated) on:
- torch_geometric/data/in_memory_dataset.py
- torch_geometric/data/separate.py
- torch_geometric/data/collate.py (2 threads)
@rusty1s (Member) commented Jun 26, 2022

Please also add the change to the CHANGELOG.md.

@jan-meissner (Contributor, Author) commented Jun 27, 2022

Just as a heads up: I haven't done a full test.

@rusty1s (Member) left a comment

@jan-meissner This looks super good to me. Let me know once this is ready to merge :)

@jan-meissner (Contributor, Author) commented Jun 29, 2022

> @jan-meissner This looks super good to me. Let me know once this is ready to merge :)

@rusty1s Thanks! I'm done with it so it should be ready for merge.

@rusty1s rusty1s changed the title Add batch and ptr vectors for a list of tensors and nested dicts. Add batch and ptr vectors for a list of tensors and nested dicts Jun 30, 2022
@rusty1s rusty1s merged commit 9bb19b6 into pyg-team:master Jun 30, 2022