Introduce "disable_dynamic_shapes" experimental flag. Adding its use
Browse files Browse the repository at this point in the history
into to_dense_batch function.

Some devices do not support dynamic shapes and can compile and optimize
only static graphs. The ability to set and read the
"disable_dynamic_shapes" flag lets implementors provide
static-shape-friendly implementations and report user-friendly error
messages when the use of dynamic shapes cannot be avoided.
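
As a minimal sketch of the intended usage (relying on the pre-existing `experimental_mode` context manager and `is_experimental_mode_enabled` helper from `torch_geometric.experimental`; the fallback branch is illustrative):

from torch_geometric.experimental import (
    experimental_mode,
    is_experimental_mode_enabled,
)

with experimental_mode('disable_dynamic_shapes'):
    if is_experimental_mode_enabled('disable_dynamic_shapes'):
        # Take a static-shape-friendly code path here, or raise a
        # user-friendly error if dynamic shapes cannot be avoided.
        pass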
piotrchmiel committed May 16, 2023
1 parent 50cbd43 commit 0a048d8
Showing 5 changed files with 99 additions and 6 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -37,7 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.jit.script` within `MessagePassing` layers without `torch_sparse` being installed ([#7061](https://github.com/pyg-team/pytorch_geometric/pull/7061), [#7062](https://github.com/pyg-team/pytorch_geometric/pull/7062))
- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))
- Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))

- Added `disable_dynamic_shapes` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
### Changed

- Fixed a bug in which inputs where modified in-place in `to_hetero_with_bases` ([#7363](https://github.com/pyg-team/pytorch_geometric/pull/7363))
@@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))
- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))
- Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))

- Added usage of the `disable_dynamic_shapes` experimental flag in the `to_dense_batch` function ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
### Removed

- Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117))
3 changes: 1 addition & 2 deletions test/test_experimental.py
@@ -7,8 +7,7 @@
)


-@pytest.mark.skip(reason='No experimental options available right now.')
-@pytest.mark.parametrize('options', [None])
+@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
def test_experimental_mode(options):
assert is_experimental_mode_enabled(options) is False
with experimental_mode(options):
28 changes: 28 additions & 0 deletions test/utils/test_to_dense_batch.py
@@ -1,5 +1,7 @@
import pytest
import torch

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.testing import is_full_test
from torch_geometric.utils import to_dense_batch

@@ -52,3 +54,29 @@ def test_to_dense_batch():

out, mask = to_dense_batch(x, batch, batch_size=4)
assert out.size() == (4, 3, 2)

with set_experimental_mode(True, "disable_dynamic_shapes"):
with pytest.raises(ValueError):
out, mask = to_dense_batch(x, batch, max_num_nodes=6)
with pytest.raises(ValueError):
out, mask = to_dense_batch(x, batch, batch_size=4)
with pytest.raises(ValueError):
out, mask = to_dense_batch(x)

out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)
assert out.size() == (1, 6, 2)
assert out.tolist() == [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
[11, 12]]]
assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]

out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
assert out.size() == (3, 10, 2)
assert torch.equal(out[0, :3],
torch.Tensor([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]))
assert torch.equal(out[1, :3],
torch.Tensor([[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]))
assert torch.equal(
out[2, :3], torch.Tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]))
assert mask.tolist() == [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]
60 changes: 59 additions & 1 deletion torch_geometric/experimental.py
@@ -1,6 +1,8 @@
import functools
import inspect
from typing import List, Optional, Union

-__experimental_flag__ = {}
+__experimental_flag__ = {'disable_dynamic_shapes': False}

Options = Optional[Union[str, List[str]]]

@@ -77,3 +79,59 @@ def __enter__(self):
def __exit__(self, *args):
for option, value in self.previous_state.items():
__experimental_flag__[option] = value


def disable_dynamic_shapes(required_args: Union[list, tuple]):

if not required_args:
raise ValueError('required_args list cannot be empty')

def decorator(func):
func_spec = inspect.getfullargspec(func)

required_args_pos = {}

for arg_name in required_args:
if arg_name not in func_spec.args:
raise ValueError(
f'function {func} does not take a {arg_name} argument')
required_args_pos[arg_name] = func_spec.args.index(arg_name)

num_args = len(func_spec.args)
num_default_args = 0 if func_spec.defaults is None else len(
func_spec.defaults)
num_positional_args = num_args - num_default_args

@functools.wraps(func)
def wrapper(*args, **kwargs):
dynamic_shapes_disabled = is_experimental_mode_enabled(
"disable_dynamic_shapes")
if dynamic_shapes_disabled:
num_passed_args = len(args)

def validate_param(param_name, value):
if value is None:
raise ValueError(
"Dynamic shapes disabled. Mandatory parameter "
f"`{param_name}` cannot be None.")

for param_name in required_args:
value = None
index = required_args_pos[param_name]
if index < num_passed_args:
value = args[index]
elif param_name in kwargs:
value = kwargs[param_name]
elif num_default_args:
defaults_index = index - num_positional_args

if defaults_index < num_default_args:
value = func_spec.defaults[defaults_index]

validate_param(param_name, value)

return func(*args, **kwargs)

return wrapper

return decorator
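
For illustration, a hedged usage sketch of the decorator (the `pad_rows` helper is hypothetical, not part of this commit): when the `disable_dynamic_shapes` experimental flag is enabled, every argument named in `required_args` must resolve to a non-`None` value, whether it is passed positionally, by keyword, or through a default.

import torch

from torch_geometric.experimental import (
    disable_dynamic_shapes,
    experimental_mode,
)


@disable_dynamic_shapes(required_args=['max_num_nodes'])
def pad_rows(x, max_num_nodes=None):
    # Hypothetical helper: pad `x` with zero rows up to a static length.
    return torch.cat([x, x.new_zeros(max_num_nodes - x.size(0), x.size(1))])


x = torch.rand(3, 2)
with experimental_mode('disable_dynamic_shapes'):
    out = pad_rows(x, max_num_nodes=8)  # OK: out.size() == (8, 2)
    try:
        pad_rows(x)  # `max_num_nodes` falls back to its `None` default
    except ValueError as e:
        print(e)  # "Dynamic shapes disabled. Mandatory parameter ..."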
10 changes: 9 additions & 1 deletion torch_geometric/utils/to_dense_batch.py
@@ -3,9 +3,14 @@
import torch
from torch import Tensor

from torch_geometric.experimental import (
disable_dynamic_shapes,
is_experimental_mode_enabled,
)
from torch_geometric.utils import scatter


@disable_dynamic_shapes(required_args=["batch_size", "max_num_nodes"])
def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
fill_value: float = 0., max_num_nodes: Optional[int] = None,
batch_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
@@ -100,9 +105,12 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

filter_nodes = False
dynamic_shapes_disabled = is_experimental_mode_enabled(
"disable_dynamic_shapes")

if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())
-    elif num_nodes.max() > max_num_nodes:
+    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
filter_nodes = True

tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
Expand Down
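
Mirroring the new tests above, a brief end-to-end sketch: with the flag enabled, `to_dense_batch` insists on explicit `batch_size` and `max_num_nodes` arguments, so the output shape is fixed entirely by the caller rather than derived from the data.

import torch

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.utils import to_dense_batch

x = torch.rand(6, 2)
batch = torch.tensor([0, 0, 1, 2, 2, 2])

with set_experimental_mode(True, 'disable_dynamic_shapes'):
    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
    assert out.size() == (3, 10, 2)  # static: determined by the arguments alone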
