Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce disable_dynamic_shapes experimental flag; adding its use into to_dense_batch function #7246

Merged
merged 4 commits into from
Jun 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
Expand Down
3 changes: 1 addition & 2 deletions test/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
)


@pytest.mark.skip(reason='No experimental options available right now.')
@pytest.mark.parametrize('options', [None])
@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
def test_experimental_mode(options):
assert is_experimental_mode_enabled(options) is False
with experimental_mode(options):
Expand Down
22 changes: 22 additions & 0 deletions test/utils/test_to_dense_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import Tensor

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.testing import onlyFullTest
from torch_geometric.utils import to_dense_batch

Expand Down Expand Up @@ -54,6 +55,27 @@ def test_to_dense_batch(fill):
assert out.size() == (4, 3, 2)


def test_to_dense_batch_disable_dynamic_shapes():
x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
batch = torch.tensor([0, 0, 1, 2, 2, 2])

with set_experimental_mode(True, 'disable_dynamic_shapes'):
with pytest.raises(ValueError, match="'batch_size' needs to be set"):
out, mask = to_dense_batch(x, batch, max_num_nodes=6)
with pytest.raises(ValueError, match="'max_num_nodes' needs to be"):
out, mask = to_dense_batch(x, batch, batch_size=4)
with pytest.raises(ValueError, match="'batch_size' needs to be set"):
out, mask = to_dense_batch(x)

out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
assert out.size() == (1, 6, 2)
assert mask.size() == (1, 6)

out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
assert out.size() == (3, 10, 2)
assert mask.size() == (3, 10)


@onlyFullTest
def test_to_dense_batch_jit():
@torch.jit.script
Expand Down
51 changes: 49 additions & 2 deletions torch_geometric/experimental.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List, Optional, Union
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

__experimental_flag__ = {}
__experimental_flag__ = {'disable_dynamic_shapes': False}

Options = Optional[Union[str, List[str]]]

Expand Down Expand Up @@ -77,3 +79,48 @@ def __enter__(self):
def __exit__(self, *args):
for option, value in self.previous_state.items():
__experimental_flag__[option] = value


def disable_dynamic_shapes(required_args: List[str]) -> Callable:
r"""A decorator that disables the usage of dynamic shapes for the given
arguments, i.e., it will raise an error in case :obj:`required_args` are
not passed and needs to be automatically inferred."""
def decorator(func: Callable) -> Callable:
spec = inspect.getfullargspec(func)

required_args_pos: Dict[str, int] = {}
for arg_name in required_args:
if arg_name not in spec.args:
raise ValueError(f"The function '{func}' does not have a "
f"'{arg_name}' argument")
required_args_pos[arg_name] = spec.args.index(arg_name)

num_args = len(spec.args)
num_default_args = 0 if spec.defaults is None else len(spec.defaults)
num_positional_args = num_args - num_default_args

@functools.wraps(func)
def wrapper(*args, **kwargs):
if not is_experimental_mode_enabled('disable_dynamic_shapes'):
return func(*args, **kwargs)

for required_arg in required_args:
index = required_args_pos[required_arg]

value: Optional[Any] = None
if index < len(args):
value = args[index]
elif required_arg in kwargs:
value = kwargs[required_arg]
elif num_default_args > 0:
value = spec.defaults[index - num_positional_args]

if value is None:
raise ValueError(f"Dynamic shapes disabled. Argument "
f"'{required_arg}' needs to be set")

return func(*args, **kwargs)

return wrapper

return decorator
10 changes: 9 additions & 1 deletion torch_geometric/utils/to_dense_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import torch
from torch import Tensor

from torch_geometric.experimental import (
disable_dynamic_shapes,
is_experimental_mode_enabled,
)
from torch_geometric.utils import scatter


@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
x: Tensor,
batch: Optional[Tensor] = None,
Expand Down Expand Up @@ -106,9 +111,12 @@ def to_dense_batch(
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

filter_nodes = False
dynamic_shapes_disabled = is_experimental_mode_enabled(
'disable_dynamic_shapes')

if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())
elif num_nodes.max() > max_num_nodes:
elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
filter_nodes = True

tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
Expand Down