Introduce the `disable_dynamic_shapes` experimental flag and add its use in the `to_dense_batch` function (#7246)

Some devices do not support dynamic shapes and can compile and optimize only
static graphs. The ability to set and read the `disable_dynamic_shapes` flag
allows implementors to provide static-shape-friendly implementations and to
report user-friendly error messages when dynamic shapes cannot be avoided.
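
As a usage sketch (not part of the diff below; it assumes `experimental_mode`
is importable from `torch_geometric.experimental`, where it is defined),
enabling the flag makes `to_dense_batch` require explicit shape arguments
instead of inferring them from the data:

```python
import torch

from torch_geometric.experimental import experimental_mode
from torch_geometric.utils import to_dense_batch

x = torch.randn(6, 2)
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # Three graphs with 2, 1 and 3 nodes.

with experimental_mode('disable_dynamic_shapes'):
    # to_dense_batch(x, batch) would raise a ValueError here, since
    # 'batch_size' and 'max_num_nodes' can no longer be inferred.
    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=3)

assert out.size() == (3, 3, 2)  # Fully determined by the explicit arguments.
assert mask.size() == (3, 3)
```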

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
piotrchmiel and rusty1s authored Jun 4, 2023
1 parent c1bffa1 commit d90842d
Showing 5 changed files with 82 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `disable_dynamic_shapes` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
- Added the option to override `use_segmm` selection in `HeteroLinear` ([#7474](https://github.com/pyg-team/pytorch_geometric/pull/7474))
- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
3 changes: 1 addition & 2 deletions test/test_experimental.py
@@ -7,8 +7,7 @@
)


-@pytest.mark.skip(reason='No experimental options available right now.')
-@pytest.mark.parametrize('options', [None])
+@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
def test_experimental_mode(options):
    assert is_experimental_mode_enabled(options) is False
    with experimental_mode(options):
22 changes: 22 additions & 0 deletions test/utils/test_to_dense_batch.py
@@ -4,6 +4,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.testing import onlyFullTest
from torch_geometric.utils import to_dense_batch

@@ -54,6 +55,27 @@ def test_to_dense_batch(fill):
    assert out.size() == (4, 3, 2)


def test_to_dense_batch_disable_dynamic_shapes():
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    with set_experimental_mode(True, 'disable_dynamic_shapes'):
        with pytest.raises(ValueError, match="'batch_size' needs to be set"):
            out, mask = to_dense_batch(x, batch, max_num_nodes=6)
        with pytest.raises(ValueError, match="'max_num_nodes' needs to be"):
            out, mask = to_dense_batch(x, batch, batch_size=4)
        with pytest.raises(ValueError, match="'batch_size' needs to be set"):
            out, mask = to_dense_batch(x)

        out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
        assert out.size() == (1, 6, 2)
        assert mask.size() == (1, 6)

        out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
        assert out.size() == (3, 10, 2)
        assert mask.size() == (3, 10)


@onlyFullTest
def test_to_dense_batch_jit():
@torch.jit.script
51 changes: 49 additions & 2 deletions torch_geometric/experimental.py
@@ -1,6 +1,8 @@
-from typing import List, Optional, Union
+import functools
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union

-__experimental_flag__ = {}
+__experimental_flag__ = {'disable_dynamic_shapes': False}

Options = Optional[Union[str, List[str]]]

@@ -77,3 +79,48 @@ def __enter__(self):
    def __exit__(self, *args):
        for option, value in self.previous_state.items():
            __experimental_flag__[option] = value


def disable_dynamic_shapes(required_args: List[str]) -> Callable:
    r"""A decorator that disables the usage of dynamic shapes for the given
    arguments, i.e., it will raise an error in case :obj:`required_args` are
    not passed and need to be automatically inferred."""
    def decorator(func: Callable) -> Callable:
        spec = inspect.getfullargspec(func)

        required_args_pos: Dict[str, int] = {}
        for arg_name in required_args:
            if arg_name not in spec.args:
                raise ValueError(f"The function '{func}' does not have a "
                                 f"'{arg_name}' argument")
            required_args_pos[arg_name] = spec.args.index(arg_name)

        num_args = len(spec.args)
        num_default_args = 0 if spec.defaults is None else len(spec.defaults)
        num_positional_args = num_args - num_default_args

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
                return func(*args, **kwargs)

            for required_arg in required_args:
                index = required_args_pos[required_arg]

                value: Optional[Any] = None
                if index < len(args):
                    value = args[index]
                elif required_arg in kwargs:
                    value = kwargs[required_arg]
                elif num_default_args > 0:
                    value = spec.defaults[index - num_positional_args]

                if value is None:
                    raise ValueError(f"Dynamic shapes disabled. Argument "
                                     f"'{required_arg}' needs to be set")

            return func(*args, **kwargs)

        return wrapper

    return decorator
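
Beyond `to_dense_batch`, the decorator can guard any operator whose output
shape would otherwise depend on the data. A minimal sketch (the
`pad_to_length` helper and its arguments are illustrative, not part of this
commit):

```python
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.experimental import (
    disable_dynamic_shapes,
    experimental_mode,
)


@disable_dynamic_shapes(required_args=['max_length'])
def pad_to_length(x: Tensor, max_length: Optional[int] = None) -> Tensor:
    if max_length is None:  # Data-dependent fallback -> dynamic output shape.
        max_length = x.size(0)
    num = min(x.size(0), max_length)
    out = x.new_zeros(max_length, *x.size()[1:])
    out[:num] = x[:num]
    return out


x = torch.randn(4, 8)
pad_to_length(x)  # Fine: dynamic shapes are allowed by default.

with experimental_mode('disable_dynamic_shapes'):
    pad_to_length(x, max_length=16)  # OK: the output shape is fully specified.
    # pad_to_length(x) would raise:
    # ValueError: Dynamic shapes disabled. Argument 'max_length' needs to be set
```
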
10 changes: 9 additions & 1 deletion torch_geometric/utils/to_dense_batch.py
@@ -3,9 +3,14 @@
import torch
from torch import Tensor

from torch_geometric.experimental import (
    disable_dynamic_shapes,
    is_experimental_mode_enabled,
)
from torch_geometric.utils import scatter


@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
    x: Tensor,
    batch: Optional[Tensor] = None,
@@ -106,9 +111,12 @@ def to_dense_batch(
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    filter_nodes = False
    dynamic_shapes_disabled = is_experimental_mode_enabled(
        'disable_dynamic_shapes')

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())
-    elif num_nodes.max() > max_num_nodes:
+    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
        filter_nodes = True

    tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
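
The `is_experimental_mode_enabled` check added to `to_dense_batch` above
follows a pattern other operators can reuse: only fall back to data-dependent
shape inference when dynamic shapes are allowed. A sketch with a hypothetical
`histogram` helper (not part of this commit):

```python
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.experimental import is_experimental_mode_enabled


def histogram(index: Tensor, num_bins: Optional[int] = None) -> Tensor:
    if num_bins is None:
        # Inferring the number of bins from the data yields a dynamically
        # shaped output, so only do so when dynamic shapes are allowed.
        if is_experimental_mode_enabled('disable_dynamic_shapes'):
            raise ValueError("Dynamic shapes disabled. Argument 'num_bins' "
                             "needs to be set")
        num_bins = int(index.max()) + 1
    out = torch.zeros(num_bins, dtype=torch.long)
    out.scatter_add_(0, index, torch.ones_like(index))
    return out


index = torch.tensor([0, 0, 1, 2, 2, 2])
assert histogram(index).tolist() == [2, 1, 3]  # Inferred: 3 bins.
assert histogram(index, num_bins=5).tolist() == [2, 1, 3, 0, 0]  # Static: 5.
```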
