Commit

Add is_compiling() checks instead of disabling extensions during `torch_geometric.compile` (#8698)
rusty1s authored Jan 1, 2024
1 parent a1ef7bf commit b532160
Showing 24 changed files with 66 additions and 75 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -84,6 +84,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed disabling of extension packages during `torch_geometric.compile` ([#8698](https://github.com/pyg-team/pytorch_geometric/pull/8698))

## [2.4.0] - 2023-10-12

### Added
7 changes: 2 additions & 5 deletions docs/source/advanced/compile.rst
@@ -16,14 +16,11 @@ In this tutorial, we show how to optimize your custom :pyg:`PyG` model via :meth

By default, :meth:`torch.compile` struggles to optimize a custom :pyg:`PyG` model since its underlying :class:`~torch_geometric.nn.conv.MessagePassing` interface is JIT-unfriendly due to its generality.
As such, in :pyg:`PyG 2.3`, we introduce :meth:`torch_geometric.compile`, a wrapper around :meth:`torch.compile` with the same signature.

:meth:`torch_geometric.compile` applies further optimizations to make :pyg:`PyG` models more compiler-friendly.
Specifically, it:

#. Temporarily disables the usage of the extension packages :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` during GNN execution workflows (since these are not *yet* directly optimizable by :pytorch:`PyTorch`).
From :pyg:`PyG 2.3` onwards, these packages are purely optional and not required anymore for running :pyg:`PyG` models (but :obj:`pyg_lib` may be required for graph sampling routines).
#. Converts all instances of :class:`~torch_geometric.nn.conv.MessagePassing` modules into their jittable instances (see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`).

#. Converts all instances of :class:`~torch_geometric.nn.conv.MessagePassing` modules into their jittable instances (see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`)
#. Disables generation of device asserts during fused gather/scatter calls to avoid performance impacts.

Without these adjustments, :meth:`torch.compile` may currently fail to correctly optimize your :pyg:`PyG` model.
We are working on fully relying on :meth:`torch.compile` for future releases.
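
For reference, a minimal usage sketch of the wrapper described above (the model choice and tensor shapes below are arbitrary; :meth:`torch_geometric.compile` forwards its arguments to :meth:`torch.compile`):

    import torch
    import torch_geometric
    from torch_geometric.nn import GraphSAGE

    # Toy model and inputs (shapes chosen arbitrarily for illustration):
    model = GraphSAGE(in_channels=64, hidden_channels=64, num_layers=2)
    x = torch.randn(10, 64)
    edge_index = torch.randint(0, 10, (2, 32))

    # Same signature as `torch.compile`; returns an optimized callable:
    compiled_model = torch_geometric.compile(model)
    out = compiled_model(x, edge_index)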
2 changes: 0 additions & 2 deletions test/nn/conv/test_hetero_conv.py
@@ -17,7 +17,6 @@
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
get_random_edge_index,
onlyLinux,
withCUDA,
@@ -182,7 +181,6 @@ def test_hetero_conv_with_dot_syntax_node_types():

@withCUDA
@onlyLinux
@disableExtensions
@withPackage('torch>=2.1.0')
def test_compile_hetero_conv_graph_breaks(device):
import torch._dynamo as dynamo
4 changes: 1 addition & 3 deletions test/nn/conv/test_sage_conv.py
@@ -5,7 +5,6 @@
from torch_geometric.nn import MLPAggregation, SAGEConv
from torch_geometric.testing import (
assert_module,
disableExtensions,
is_full_test,
onlyLinux,
withCUDA,
@@ -133,7 +132,6 @@ def test_multi_aggr_sage_conv(aggr_kwargs):

@withCUDA
@onlyLinux
@disableExtensions
@withPackage('torch>=2.1.0')
def test_compile_multi_aggr_sage_conv(device):
import torch._dynamo as dynamo
@@ -154,4 +152,4 @@ def test_compile_multi_aggr_sage_conv(device):

expected = conv(x, edge_index)
out = compiled_conv(x, edge_index)
assert torch.allclose(out, expected)
assert torch.allclose(out, expected, atol=1e-6)
5 changes: 1 addition & 4 deletions test/nn/models/test_basic_gnn.py
@@ -9,14 +9,13 @@
import torch.nn.functional as F

import torch_geometric.typing
from torch_geometric.compile import to_jittable
from torch_geometric._compile import to_jittable
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
onlyFullTest,
onlyLinux,
onlyNeighborSampler,
@@ -208,7 +207,6 @@ def test_basic_gnn_inference(get_dataset, jk):
@withCUDA
@onlyLinux
@onlyFullTest
@disableExtensions
@withPackage('torch>=2.0.0')
def test_compile(device):
x = torch.randn(3, 8, device=device)
@@ -334,7 +332,6 @@ def test_trim_to_layer():

@withCUDA
@onlyLinux
@disableExtensions
@withPackage('torch>=2.1.0')
@pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA])
def test_compile_graph_breaks(Model, device):
2 changes: 0 additions & 2 deletions test/nn/test_compile_basic.py
@@ -3,7 +3,6 @@
import torch_geometric
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
onlyFullTest,
onlyLinux,
withCUDA,
@@ -47,7 +46,6 @@ def fused_gather_scatter(x, edge_index, reduce=['sum', 'mean', 'max']):
@withCUDA
@onlyLinux
@onlyFullTest
@disableExtensions
@withPackage('torch>=2.0.0')
def test_torch_compile(device):
x = torch.randn(10, 16, device=device)
2 changes: 0 additions & 2 deletions test/nn/test_compile_conv.py
@@ -6,7 +6,6 @@
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
onlyFullTest,
onlyLinux,
withCUDA,
@@ -30,7 +29,6 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
@withCUDA
@onlyLinux
@onlyFullTest
@disableExtensions
@withPackage('torch>=2.0.0')
@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])
def test_compile_conv(device, Conv):
2 changes: 0 additions & 2 deletions test/nn/test_compile_dynamic.py
@@ -5,7 +5,6 @@

import torch_geometric
from torch_geometric.testing import (
disableExtensions,
get_random_edge_index,
onlyFullTest,
onlyLinux,
@@ -30,7 +29,6 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
@withCUDA
@onlyLinux
@onlyFullTest
@disableExtensions
@withPackage('torch>2.0.0')
def test_dynamic_torch_compile(device):
conv = MySAGEConv(64, 64).to(device)
2 changes: 0 additions & 2 deletions test/test_edge_index.py
@@ -19,7 +19,6 @@
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
onlyCUDA,
onlyLinux,
withCUDA,
@@ -1080,7 +1079,6 @@ def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:


@onlyLinux
@disableExtensions
@withPackage('torch>=2.1.0')
def test_compile():
import torch._dynamo as dynamo
3 changes: 2 additions & 1 deletion test/test_warnings.py
@@ -11,8 +11,9 @@ def test_warn():
warn('test')


@patch('torch_geometric.warnings._is_compiling', return_value=True)
@patch('torch_geometric.is_compiling', return_value=True)
def test_no_warn_if_compiling(_):
"""No warning should be raised to avoid graph breaks when compiling."""
with warnings.catch_warnings():
warnings.simplefilter('error')
warn('test')
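
The guard this test exercises might look roughly as follows (a minimal sketch; the actual `torch_geometric.warnings.warn` implementation may differ):

    import warnings

    import torch_geometric

    def warn(message: str) -> None:
        # Emitting Python warnings inside `torch.compile` causes graph breaks,
        # so stay silent while the model is being compiled:
        if torch_geometric.is_compiling():
            return
        warnings.warn(message)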
7 changes: 5 additions & 2 deletions test/utils/test_scatter.py
@@ -3,8 +3,9 @@
import pytest
import torch

import torch_geometric.typing
from torch_geometric.profile import benchmark
from torch_geometric.testing import disableExtensions, withCUDA, withPackage
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.utils import group_argsort, scatter
from torch_geometric.utils._scatter import scatter_argmax

@@ -95,12 +96,14 @@ def test_group_argsort(num_groups, descending, device):


@withCUDA
@disableExtensions
def test_scatter_argmax(device):
src = torch.arange(5, device=device)
index = torch.tensor([2, 2, 0, 0, 3], device=device)

old_state = torch_geometric.typing.WITH_TORCH_SCATTER
torch_geometric.typing.WITH_TORCH_SCATTER = False
argmax = scatter_argmax(src, index, dim_size=6)
torch_geometric.typing.WITH_TORCH_SCATTER = old_state
assert argmax.tolist() == [3, 5, 1, 4, 5, 5]
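
If this save/toggle/restore pattern is needed in more tests, it could be factored into a small helper (a hypothetical sketch, not part of this commit):

    import contextlib

    import torch_geometric.typing

    @contextlib.contextmanager
    def without_torch_scatter():
        # Force the pure-PyTorch code path and restore the flag even if the
        # wrapped block raises:
        old_state = torch_geometric.typing.WITH_TORCH_SCATTER
        torch_geometric.typing.WITH_TORCH_SCATTER = False
        try:
            yield
        finally:
            torch_geometric.typing.WITH_TORCH_SCATTER = old_state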


11 changes: 6 additions & 5 deletions torch_geometric/__init__.py
@@ -1,4 +1,9 @@
from ._compile import compile, is_compiling
from .edge_index import EdgeIndex
from .seed import seed_everything
from .home import get_home_dir, set_home_dir
from .isinstance import is_torch_instance
from .debug import is_debug_enabled, debug, set_debug

import torch_geometric.utils
import torch_geometric.data
@@ -10,11 +15,6 @@
import torch_geometric.explain
import torch_geometric.profile

from .seed import seed_everything
from .home import get_home_dir, set_home_dir
from .compile import compile
from .isinstance import is_torch_instance
from .debug import is_debug_enabled, debug, set_debug
from .experimental import (is_experimental_mode_enabled, experimental_mode,
set_experimental_mode)
from .lazy_loader import LazyLoader
@@ -30,6 +30,7 @@
'get_home_dir',
'set_home_dir',
'compile',
'is_compiling',
'is_torch_instance',
'is_debug_enabled',
'debug',
26 changes: 11 additions & 15 deletions torch_geometric/compile.py → torch_geometric/_compile.py
@@ -12,6 +12,15 @@
"the following error: {error}")


def is_compiling() -> bool:
r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
:meth:`torch.compile`.
"""
if torch_geometric.typing.WITH_PT21:
return torch._dynamo.is_compiling()
return False # pragma: no cover
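
As an illustration of how this helper is meant to be used (a hypothetical sketch, not part of this file; `scatter_sum` below is an invented example):

    # Gate an optional extension kernel so that `torch.compile` traces the
    # pure-PyTorch fallback instead of an un-traceable custom op:
    def scatter_sum(src, index, dim_size):
        if torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling():
            import torch_scatter
            return torch_scatter.scatter(src, index, dim=0, dim_size=dim_size,
                                         reduce='sum')
        out = src.new_zeros(dim_size, *src.shape[1:])
        index = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src)
        return out.scatter_add_(0, index, src)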


def to_jittable(model: torch.nn.Module) -> torch.nn.Module:
if isinstance(model, torch_geometric.nn.MessagePassing):
try:
@@ -47,15 +56,12 @@ def compile(
Specifically, it
1. temporarily disables the usage of the extension packages
:obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib`
2. converts all instances of
1. converts all instances of
:class:`~torch_geometric.nn.conv.MessagePassing` modules into their
jittable instances
(see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`)
3. disables generation of device asserts during fused gather/scatter calls
2. disables generation of device asserts during fused gather/scatter calls
to avoid performance impacts
.. note::
@@ -75,16 +81,6 @@ def fn(model: torch.nn.Module) -> torch.nn.Module:

return fn

# Disable the usage of external extension packages:
# TODO (matthias) Disable only temporarily
prev_state = {
'WITH_INDEX_SORT': torch_geometric.typing.WITH_INDEX_SORT,
'WITH_TORCH_SCATTER': torch_geometric.typing.WITH_TORCH_SCATTER,
}
warnings.filterwarnings('ignore', ".*the 'torch-scatter' package.*")
for key in prev_state.keys():
setattr(torch_geometric.typing, key, False)

# Adjust the logging level of `torch.compile`:
# TODO (matthias) Disable only temporarily
prev_log_level = {
7 changes: 4 additions & 3 deletions torch_geometric/edge_index.py
@@ -22,6 +22,7 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.typing import SparseTensor

HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
@@ -1607,12 +1608,12 @@ def _spmm(
raise ValueError(f"'matmul(..., transpose=True)' requires "
f"'{cls_name}' to be sorted by colums")

if (torch_geometric.typing.WITH_TORCH_SPARSE
if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
and other.is_cuda): # pragma: no cover
return _torch_sparse_spmm(input, other, value, reduce, transpose)

if value is not None and value.requires_grad:
if torch_geometric.typing.WITH_TORCH_SPARSE:
if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
return _torch_sparse_spmm(input, other, value, reduce, transpose)
return _scatter_spmm(input, other, value, reduce, transpose)

@@ -1628,7 +1629,7 @@ def _spmm(
and not other.requires_grad):
return _TorchSPMM.apply(input, other, value, reduce, transpose)

if torch_geometric.typing.WITH_TORCH_SPARSE:
if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
return _torch_sparse_spmm(input, other, value, reduce, transpose)

return _scatter_spmm(input, other, value, reduce, transpose)
4 changes: 3 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
@@ -902,6 +902,9 @@ def jittable(self, typing: Optional[str] = None) -> 'MessagePassing':
with :meth:`forward` types based on :obj:`typing`, *e.g.*,
:obj:`"(Tensor, Optional[Tensor]) -> Tensor"`.
"""
if 'Jittable' in self.__class__.__name__:
return self

try:
from jinja2 import Template
except ImportError:
@@ -1025,5 +1028,4 @@ def jittable(self, typing: Optional[str] = None) -> 'MessagePassing':
cls = class_from_module_repr(cls_name, jit_module_repr)
module = cls.__new__(cls)
module.__dict__ = self.__dict__.copy()
module.jittable = None
return module
7 changes: 5 additions & 2 deletions torch_geometric/nn/conv/rgcn_conv.py
@@ -7,6 +7,7 @@

import torch_geometric.backend
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
@@ -252,7 +253,8 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
use_segment_matmul = self._use_segment_matmul_heuristic_output

if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
and self.num_bases is None and x_l.is_floating_point()
and not is_compiling() and self.num_bases is None
and x_l.is_floating_point()
and isinstance(edge_index, Tensor)):

if not self.is_sorted:
@@ -292,7 +294,8 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
return out

def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
if torch_geometric.typing.WITH_SEGMM and edge_type_ptr is not None:
if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
and edge_type_ptr is not None):
# TODO Re-weight according to edge type degree for `aggr=mean`.
return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)

6 changes: 4 additions & 2 deletions torch_geometric/nn/dense/linear.py
@@ -9,6 +9,7 @@

import torch_geometric.backend
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort, scatter
@@ -284,7 +285,8 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
assert self._use_segment_matmul_heuristic_output is not None
use_segment_matmul = self._use_segment_matmul_heuristic_output

if use_segment_matmul and torch_geometric.typing.WITH_SEGMM:
if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
and not is_compiling()):
assert self.weight is not None

perm: Optional[Tensor] = None
@@ -422,7 +424,7 @@ def forward(
use_segment_matmul = len(x_dict) >= 10

if (use_segment_matmul and torch_geometric.typing.WITH_GMM
and not torch.jit.is_scripting()):
and not is_compiling() and not torch.jit.is_scripting()):
xs, weights, biases = [], [], []
for key, lin in self.lins.items():
if key in x_dict:
3 changes: 2 additions & 1 deletion torch_geometric/nn/to_hetero_module.py
@@ -7,6 +7,7 @@
from torch import Tensor

import torch_geometric
from torch_geometric import is_compiling
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import cumsum, scatter

@@ -53,7 +54,7 @@ def dict_forward(
x_dict: Dict[Union[NodeType, EdgeType], Tensor],
) -> Dict[Union[NodeType, EdgeType], Tensor]:

if not torch_geometric.typing.WITH_PYG_LIB:
if not torch_geometric.typing.WITH_PYG_LIB or is_compiling():
return {
key:
F.linear(x_dict[key], self.hetero_module.weight[i].t()) +