From d90842d6f38e111d14353449fef8eeac54d52994 Mon Sep 17 00:00:00 2001
From: Piotr Chmiel
Date: Sun, 4 Jun 2023 13:00:25 +0200
Subject: [PATCH] Introduce `disable_dynamic_shapes` experimental flag; add
 its use to the `to_dense_batch` function (#7246)

Some devices do not support dynamic shapes and can only compile and
optimize static graphs. The ability to set and read the
`disable_dynamic_shapes` flag allows implementors to provide
static-shape-friendly implementations and to report user-friendly error
messages when the use of dynamic shapes cannot be avoided.

---------

Co-authored-by: Matthias Fey
---
 CHANGELOG.md                            |  1 +
 test/test_experimental.py               |  3 +-
 test/utils/test_to_dense_batch.py       | 22 +++++++++++
 torch_geometric/experimental.py         | 51 ++++++++++++++++++++++++-
 torch_geometric/utils/to_dense_batch.py | 10 ++++-
 5 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc77cd0708c7..f1c4f940d456 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `disable_dynamic_shapes` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
 - Added the option to override `use_segmm` selection in `HeteroLinear` ([#7474](https://github.com/pyg-team/pytorch_geometric/pull/7474))
 - Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
 - Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
diff --git a/test/test_experimental.py b/test/test_experimental.py
index 39bfd71f359f..6d1cd4e513dc 100644
--- a/test/test_experimental.py
+++ b/test/test_experimental.py
@@ -7,8 +7,7 @@
 )
 
 
-@pytest.mark.skip(reason='No experimental options available right now.')
-@pytest.mark.parametrize('options', [None])
+@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
 def test_experimental_mode(options):
     assert is_experimental_mode_enabled(options) is False
     with experimental_mode(options):
diff --git a/test/utils/test_to_dense_batch.py b/test/utils/test_to_dense_batch.py
index 0c84aef88641..1611cb39fb69 100644
--- a/test/utils/test_to_dense_batch.py
+++ b/test/utils/test_to_dense_batch.py
@@ -4,6 +4,7 @@
 import torch
 from torch import Tensor
 
+from torch_geometric.experimental import set_experimental_mode
 from torch_geometric.testing import onlyFullTest
 from torch_geometric.utils import to_dense_batch
 
@@ -54,6 +55,27 @@ def test_to_dense_batch(fill):
     assert out.size() == (4, 3, 2)
 
 
+def test_to_dense_batch_disable_dynamic_shapes():
+    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
+    batch = torch.tensor([0, 0, 1, 2, 2, 2])
+
+    with set_experimental_mode(True, 'disable_dynamic_shapes'):
+        with pytest.raises(ValueError, match="'batch_size' needs to be set"):
+            out, mask = to_dense_batch(x, batch, max_num_nodes=6)
+        with pytest.raises(ValueError, match="'max_num_nodes' needs to be"):
+            out, mask = to_dense_batch(x, batch, batch_size=4)
+        with pytest.raises(ValueError, match="'batch_size' needs to be set"):
+            out, mask = to_dense_batch(x)
+
+    out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
+    assert out.size() == (1, 6, 2)
+    assert mask.size() == (1, 6)
+
+    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
+    assert out.size() == (3, 10, 2)
+    assert mask.size() == (3, 10)
+
+
 @onlyFullTest
 def test_to_dense_batch_jit():
     @torch.jit.script
diff --git a/torch_geometric/experimental.py b/torch_geometric/experimental.py
index c4134c5aa6ae..9177f2cb80b1 100644
--- a/torch_geometric/experimental.py
+++ b/torch_geometric/experimental.py
@@ -1,6 +1,8 @@
-from typing import List, Optional, Union
+import functools
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
 
-__experimental_flag__ = {}
+__experimental_flag__ = {'disable_dynamic_shapes': False}
 
 Options = Optional[Union[str, List[str]]]
 
@@ -77,3 +79,48 @@ def __enter__(self):
     def __exit__(self, *args):
         for option, value in self.previous_state.items():
             __experimental_flag__[option] = value
+
+
+def disable_dynamic_shapes(required_args: List[str]) -> Callable:
+    r"""A decorator that disables the usage of dynamic shapes for the given
+    arguments, i.e., it will raise an error if any of :obj:`required_args`
+    is not passed and needs to be automatically inferred."""
+    def decorator(func: Callable) -> Callable:
+        spec = inspect.getfullargspec(func)
+
+        required_args_pos: Dict[str, int] = {}
+        for arg_name in required_args:
+            if arg_name not in spec.args:
+                raise ValueError(f"The function '{func}' does not have a "
+                                 f"'{arg_name}' argument")
+            required_args_pos[arg_name] = spec.args.index(arg_name)
+
+        num_args = len(spec.args)
+        num_default_args = 0 if spec.defaults is None else len(spec.defaults)
+        num_positional_args = num_args - num_default_args
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
+                return func(*args, **kwargs)
+
+            for required_arg in required_args:
+                index = required_args_pos[required_arg]
+
+                value: Optional[Any] = None
+                if index < len(args):
+                    value = args[index]
+                elif required_arg in kwargs:
+                    value = kwargs[required_arg]
+                elif num_default_args > 0:
+                    value = spec.defaults[index - num_positional_args]
+
+                if value is None:
+                    raise ValueError(f"Dynamic shapes disabled. Argument "
+                                     f"'{required_arg}' needs to be set")
+
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/torch_geometric/utils/to_dense_batch.py b/torch_geometric/utils/to_dense_batch.py
index 1e5538fe3c0e..c6857caf278b 100644
--- a/torch_geometric/utils/to_dense_batch.py
+++ b/torch_geometric/utils/to_dense_batch.py
@@ -3,9 +3,14 @@
 import torch
 from torch import Tensor
 
+from torch_geometric.experimental import (
+    disable_dynamic_shapes,
+    is_experimental_mode_enabled,
+)
 from torch_geometric.utils import scatter
 
 
+@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
 def to_dense_batch(
     x: Tensor,
     batch: Optional[Tensor] = None,
@@ -106,9 +111,12 @@ def to_dense_batch(
     cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
 
     filter_nodes = False
+    dynamic_shapes_disabled = is_experimental_mode_enabled(
+        'disable_dynamic_shapes')
+
     if max_num_nodes is None:
         max_num_nodes = int(num_nodes.max())
-    elif num_nodes.max() > max_num_nodes:
+    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
         filter_nodes = True
 
     tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
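
A minimal usage sketch of the flag introduced above (illustrative only, not
part of the patch; the tensor values are made up, while all API names are
taken from the diff):

    import torch

    from torch_geometric.experimental import set_experimental_mode
    from torch_geometric.utils import to_dense_batch

    # Three graphs with 2, 1, and 3 nodes, respectively.
    x = torch.randn(6, 2)
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    with set_experimental_mode(True, 'disable_dynamic_shapes'):
        # With both static-shape arguments given, the output shape is
        # known upfront and independent of the input data:
        out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=4)
        print(out.size())   # torch.Size([3, 4, 2])
        print(mask.size())  # torch.Size([3, 4])

        # Omitting either argument raises instead of silently falling
        # back to a data-dependent (dynamic) shape:
        # to_dense_batch(x, batch)  # ValueError: "'batch_size' needs to be set"

Outside the experimental mode, `to_dense_batch` behaves as before and infers
both values from the data when they are not given.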
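
The decorator itself is generic, so model or library authors can attach it to
their own functions. A sketch under assumptions: `pad_to_length` and its
`length` argument are hypothetical names invented for this example; only
`disable_dynamic_shapes` and `set_experimental_mode` come from the patch:

    from typing import Optional

    import torch
    from torch import Tensor

    from torch_geometric.experimental import (
        disable_dynamic_shapes,
        set_experimental_mode,
    )

    # Hypothetical helper: pads `x` along dim 0 to a fixed `length`.
    @disable_dynamic_shapes(required_args=['length'])
    def pad_to_length(x: Tensor, length: Optional[int] = None) -> Tensor:
        if length is None:
            # Data-dependent fallback; this is exactly the path the
            # decorator forbids when dynamic shapes are disabled.
            length = int(x.size(0))
        out = x.new_zeros(length, *x.shape[1:])
        out[:x.size(0)] = x
        return out

    x = torch.randn(3, 2)
    pad_to_length(x, length=8)  # Static shape given: works in either mode.

    with set_experimental_mode(True, 'disable_dynamic_shapes'):
        pad_to_length(x)  # Raises ValueError: "'length' needs to be set"

Raising at call time, before any tensor work happens, gives the user an
actionable message instead of a compilation failure deep inside a
static-graph backend.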