Introduce "disable_dynamic_shapes" experimental flag and use it in the
to_dense_batch function.

Some devices do not support dynamic shapes: they compile and optimize
only static graphs. The ability to set and read the
"disable_dynamic_shapes" flag allows implementors to provide
static-shape-friendly implementations, and to report user-friendly
messages when the use of dynamic shapes cannot be avoided.
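
A minimal usage sketch of the new flag, using the set_experimental_mode
context manager exercised in the test changes below; the tensors are
illustrative only:

import torch

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.utils import to_dense_batch

x = torch.randn(6, 2)                     # six nodes, two features each
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # three graphs in the batch

with set_experimental_mode(True, 'disable_dynamic_shapes'):
    # `batch_size` and `max_num_nodes` pin the padded output shape up
    # front, so no shape needs to be derived from tensor values at runtime:
    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
    assert out.size() == (3, 10, 2)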
piotrchmiel committed Apr 27, 2023
1 parent 5778c65 commit b397d85
Showing 4 changed files with 46 additions and 6 deletions.
3 changes: 1 addition & 2 deletions test/test_experimental.py
@@ -7,8 +7,7 @@
 )
 
 
-@pytest.mark.skip(reason='No experimental options available right now.')
-@pytest.mark.parametrize('options', [None])
+@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])
 def test_experimental_mode(options):
     assert is_experimental_mode_enabled(options) is False
     with experimental_mode(options):
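For context, the behavior this test asserts, condensed into a sketch
(the remainder of the test body is not shown in the diff above, so the
assertions after entering the context manager are an assumption):

from torch_geometric.experimental import (
    experimental_mode,
    is_experimental_mode_enabled,
)

# Flags default to off and are only active inside the context manager:
assert is_experimental_mode_enabled('disable_dynamic_shapes') is False
with experimental_mode('disable_dynamic_shapes'):
    assert is_experimental_mode_enabled('disable_dynamic_shapes') is True
assert is_experimental_mode_enabled('disable_dynamic_shapes') is False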
27 changes: 27 additions & 0 deletions test/utils/test_to_dense_batch.py
@@ -1,5 +1,7 @@
 import pytest
 import torch
 
+from torch_geometric.experimental import set_experimental_mode
 from torch_geometric.testing import is_full_test
 from torch_geometric.utils import to_dense_batch

@@ -52,3 +54,28 @@ def test_to_dense_batch():
 
     out, mask = to_dense_batch(x, batch, batch_size=4)
     assert out.size() == (4, 3, 2)
+
+    with set_experimental_mode(True, "disable_dynamic_shapes"):
+        with pytest.raises(RuntimeError):
+            out, mask = to_dense_batch(x, batch, batch_size=4)
+        with pytest.raises(RuntimeError):
+            out, mask = to_dense_batch(x, batch, batch_size=4)
+
+    out, mask = to_dense_batch(x)
+
+    assert out.size() == (1, 6, 2)
+    assert out.tolist() == [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
+                             [11, 12]]]
+    assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]
+
+    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
+    assert out.size() == (3, 10, 2)
+    assert torch.equal(out[0, :3],
+                       torch.Tensor([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]))
+    assert torch.equal(out[1, :3],
+                       torch.Tensor([[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]))
+    assert torch.equal(
+        out[2, :3], torch.Tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]))
+    assert mask.tolist() == [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                             [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                             [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]
2 changes: 1 addition & 1 deletion torch_geometric/experimental.py
@@ -1,6 +1,6 @@
 from typing import List, Optional, Union
 
-__experimental_flag__ = {}
+__experimental_flag__ = {'disable_dynamic_shapes': False}
 
 Options = Optional[Union[str, List[str]]]
20 changes: 17 additions & 3 deletions torch_geometric/utils/to_dense_batch.py
@@ -3,6 +3,7 @@
 import torch
 from torch import Tensor
 
+from torch_geometric.experimental import is_experimental_mode_enabled
 from torch_geometric.utils import scatter

@@ -92,17 +93,30 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
     if batch is None:
         batch = x.new_zeros(x.size(0), dtype=torch.long)
 
+    dynamic_shapes_disabled = is_experimental_mode_enabled(
+        "disable_dynamic_shapes")
+
     if batch_size is None:
-        batch_size = int(batch.max()) + 1
+        if dynamic_shapes_disabled:
+            raise RuntimeError(
+                "Dynamic shapes disabled. `batch_size` argument should be set."
+            )
+        else:
+            batch_size = int(batch.max()) + 1
 
     num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
                         dim_size=batch_size, reduce='sum')
     cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
 
     filter_nodes = False
     if max_num_nodes is None:
-        max_num_nodes = int(num_nodes.max())
-    elif num_nodes.max() > max_num_nodes:
+        if dynamic_shapes_disabled:
+            raise RuntimeError(
+                "Dynamic shapes disabled. `max_num_nodes` argument should be "
+                "set.")
+        else:
+            max_num_nodes = int(num_nodes.max())
+    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
         filter_nodes = True
 
     tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
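Taken together, the guarded paths above behave as follows; a sketch in
which the `x` and `batch` tensors are reconstructed from the test's
assertions rather than copied from the test file:

import pytest
import torch

from torch_geometric.experimental import set_experimental_mode
from torch_geometric.utils import to_dense_batch

x = torch.arange(12, dtype=torch.float).view(6, 2) + 1  # [[1, 2], ..., [11, 12]]
batch = torch.tensor([0, 0, 1, 2, 2, 2])                # num_nodes = [2, 1, 3]

with set_experimental_mode(True, 'disable_dynamic_shapes'):
    with pytest.raises(RuntimeError):
        to_dense_batch(x, batch)            # `batch_size` would be derived
                                            # from `batch.max()` at runtime
    with pytest.raises(RuntimeError):
        to_dense_batch(x, batch, batch_size=3)  # `max_num_nodes` would be
                                                # derived from `num_nodes.max()`
    # With both arguments supplied, every shape is known statically:
    out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)
assert out.size() == (3, 10, 2)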
