From 3c1e5bb3ea8f6b588d597a296c9d479101b6bc95 Mon Sep 17 00:00:00 2001
From: Piotr Chmiel
Date: Thu, 27 Apr 2023 12:10:47 +0100
Subject: [PATCH] Introduce "disable_dynamic_shapes" experimental flag.

Add its use to the `to_dense_batch` function.

Some devices do not support dynamic shapes and can only compile and
optimize static graphs. Setting and reading the "disable_dynamic_shapes"
flag allows implementors to provide static-shape-friendly
implementations and to report user-friendly error messages when the use
of dynamic shapes cannot be avoided.
---
 CHANGELOG.md                            |  4 ++--
 test/test_experimental.py               |  3 +--
 test/utils/test_to_dense_batch.py       | 27 +++++++++++++++++++++++++
 torch_geometric/experimental.py         |  2 +-
 torch_geometric/utils/to_dense_batch.py | 20 +++++++++++++++---
 5 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 163647b09ef0d..d22ae56509305 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.jit.script` within `MessagePassing` layers without `torch_sparse` being installed ([#7061](https://github.com/pyg-team/pytorch_geometric/pull/7061), [#7062](https://github.com/pyg-team/pytorch_geometric/pull/7062))
 - Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))
 - Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))
-
+- Added the `disable_dynamic_shapes` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
 ### Changed

 - Fixed `HGTConv` utility function `_construct_src_node_feat` ([#7194](https://github.com/pyg-team/pytorch_geometric/pull/7194))
@@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))
 - Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))
 - Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
-
+- Added usage of the `disable_dynamic_shapes` experimental flag in the `to_dense_batch` function ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
 ### Removed

 - Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117))
diff --git a/test/test_experimental.py b/test/test_experimental.py
index 39bfd71f359f1..6d1cd4e513dc6 100644
--- a/test/test_experimental.py
+++ b/test/test_experimental.py
@@ -7,8 +7,7 @@
 )


-@pytest.mark.skip(reason='No experimental options available right now.')
-@pytest.mark.parametrize('options', [None])
+@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
 def test_experimental_mode(options):
     assert is_experimental_mode_enabled(options) is False
     with experimental_mode(options):
diff --git a/test/utils/test_to_dense_batch.py b/test/utils/test_to_dense_batch.py
index 958542880591b..a8bf40d5c9c85 100644
--- a/test/utils/test_to_dense_batch.py
+++ b/test/utils/test_to_dense_batch.py
@@ -1,5 +1,7 @@
+import pytest
 import torch

+from torch_geometric.experimental import set_experimental_mode
 from torch_geometric.testing import is_full_test
 from torch_geometric.utils import to_dense_batch

@@ -52,3 +54,28 @@ def test_to_dense_batch():

     out, mask = to_dense_batch(x, batch, batch_size=4)
     assert out.size() == (4, 3, 2)
+
+    with set_experimental_mode(True, 'disable_dynamic_shapes'):
+        with pytest.raises(RuntimeError):
+            out, mask = to_dense_batch(x, batch)
+        with pytest.raises(RuntimeError):
+            out, mask = to_dense_batch(x, batch, batch_size=4)
+
+        out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
+
+        assert out.size() == (1, 6, 2)
+        assert out.tolist() == [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
+                                 [11, 12]]]
+        assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]
+
+        out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
+        assert out.size() == (3, 10, 2)
+        assert torch.equal(out[0, :3],
+                           torch.Tensor([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]))
+        assert torch.equal(out[1, :3],
+                           torch.Tensor([[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]))
+        assert torch.equal(
+            out[2, :3], torch.Tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]))
+        assert mask.tolist() == [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                                 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                                 [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]
diff --git a/torch_geometric/experimental.py b/torch_geometric/experimental.py
index c4134c5aa6aec..cf857b81c7404 100644
--- a/torch_geometric/experimental.py
+++ b/torch_geometric/experimental.py
@@ -1,6 +1,6 @@
 from typing import List, Optional, Union

-__experimental_flag__ = {}
+__experimental_flag__ = {'disable_dynamic_shapes': False}

 Options = Optional[Union[str, List[str]]]

diff --git a/torch_geometric/utils/to_dense_batch.py b/torch_geometric/utils/to_dense_batch.py
index a13bc06db1d6b..1799c2f54e5e0 100644
--- a/torch_geometric/utils/to_dense_batch.py
+++ b/torch_geometric/utils/to_dense_batch.py
@@ -3,6 +3,7 @@
 import torch
 from torch import Tensor

+from torch_geometric.experimental import is_experimental_mode_enabled
 from torch_geometric.utils import scatter


@@ -92,8 +93,16 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
     if batch is None:
         batch = x.new_zeros(x.size(0), dtype=torch.long)

+    dynamic_shapes_disabled = is_experimental_mode_enabled(
+        'disable_dynamic_shapes')
+
     if batch_size is None:
-        batch_size = int(batch.max()) + 1
+        if dynamic_shapes_disabled:
+            raise RuntimeError(
+                "Dynamic shapes are disabled. "
+                "The `batch_size` argument needs to be set.")
+        else:
+            batch_size = int(batch.max()) + 1

     num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
                         dim_size=batch_size, reduce='sum')
@@ -101,8 +110,13 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,

     filter_nodes = False
     if max_num_nodes is None:
-        max_num_nodes = int(num_nodes.max())
-    elif num_nodes.max() > max_num_nodes:
+        if dynamic_shapes_disabled:
+            raise RuntimeError(
+                "Dynamic shapes are disabled. "
+                "The `max_num_nodes` argument needs to be set.")
+        else:
+            max_num_nodes = int(num_nodes.max())
+    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
         filter_nodes = True

     tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
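
As a usage sketch (not part of the patch itself): once the flag is active, the shape-determining arguments of `to_dense_batch` must be passed explicitly. The example below relies only on the `experimental_mode` context manager and `to_dense_batch` exercised in the tests above; the tensors and sizes are illustrative.

```python
import torch

from torch_geometric.experimental import experimental_mode
from torch_geometric.utils import to_dense_batch

x = torch.randn(6, 2)                     # 6 nodes with 2 features each
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # 3 graphs of sizes 2, 1 and 3

with experimental_mode('disable_dynamic_shapes'):
    # Both `batch_size` and `max_num_nodes` must be given explicitly,
    # so the output shape is known without inspecting tensor values:
    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
    assert out.size() == (3, 10, 2)
    # Omitting either argument raises a RuntimeError instead of
    # silently computing a data-dependent shape.
```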
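Implementors who want other utilities to cooperate with the flag can follow the same pattern `to_dense_batch` uses. The `require_static_args` helper below is hypothetical (not part of this patch or the library) and is sketched only to illustrate the `is_experimental_mode_enabled` check:

```python
from typing import Any

from torch_geometric.experimental import is_experimental_mode_enabled


def require_static_args(**kwargs: Any) -> None:
    # Hypothetical helper: fail fast when dynamic shapes are disabled
    # but a shape-determining argument was left unset, mirroring the
    # checks this patch adds to `to_dense_batch`.
    if not is_experimental_mode_enabled('disable_dynamic_shapes'):
        return
    for name, value in kwargs.items():
        if value is None:
            raise RuntimeError(
                f"Dynamic shapes are disabled. The `{name}` argument "
                f"needs to be set.")
```

A utility would call, e.g., `require_static_args(batch_size=batch_size, max_num_nodes=max_num_nodes)` at its entry point and could then assume that every shape it allocates is known up-front.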