Skip to content

Commit

Permalink
Merge branch 'master' into piotrc/disable_dynamic_shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jun 4, 2023
2 parents 3385f53 + c1bffa1 commit 129deee
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
- Added the option to override `use_segmm` selection in `HeteroLinear` ([#7474](https://github.com/pyg-team/pytorch_geometric/pull/7474))
- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
Expand Down
37 changes: 33 additions & 4 deletions torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ class HeteroLinear(torch.nn.Module):
:obj:`type_vec` is sorted. This avoids internal re-sorting of the
data and can improve runtime and memory efficiency.
(default: :obj:`False`)
use_segmm (bool, optional): If set to :obj:`True` and :obj:`pyg-lib` is
installed, this module will use the fused :obj:`segment_matmul`
kernel to parallelize the linear transformation across types. If
set to :obj:`False`, :obj:`segment_matmul` will not be used. If
left as :obj:`None` and :obj:`pyg-lib` is installed, the module
will determine heuristically whether to use :obj:`segment_matmul`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.Linear`.
Expand All @@ -207,16 +214,24 @@ class HeteroLinear(torch.nn.Module):
type vector :math:`(*)`
- **output:** features :math:`(*, F_{out})`
"""
def __init__(self, in_channels: int, out_channels: int, num_types: int,
is_sorted: bool = False, **kwargs):
def __init__(
self,
in_channels: int,
out_channels: int,
num_types: int,
is_sorted: bool = False,
use_segmm: Optional[bool] = None,
**kwargs,
):
super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.num_types = num_types
self.is_sorted = is_sorted
self.use_segmm: int = -1 if use_segmm is None else int(use_segmm)
self.kwargs = kwargs
self.use_segmm: int = -1

if self.in_channels == -1:
self.weight = nn.parameter.UninitializedParameter()
self._hook = self.register_forward_pre_hook(
Expand Down Expand Up @@ -313,6 +328,13 @@ class HeteroDictLinear(torch.nn.Module):
out_channels (int): Size of each output sample.
types (List[Any], optional): The keys of the input dictionary.
(default: :obj:`None`)
use_segmm (bool, optional): If set to :obj:`True` and :obj:`pyg-lib` is
installed, this module will use the fused :obj:`segment_matmul`
kernel to parallelize the linear transformation across types. If
set to :obj:`False`, :obj:`segment_matmul` will not be used. If
left as :obj:`None` and :obj:`pyg-lib` is installed, the module
will determine heuristically whether to use :obj:`segment_matmul`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.Linear`.
"""
Expand All @@ -321,6 +343,7 @@ def __init__(
in_channels: Union[int, Dict[Any, int]],
out_channels: int,
types: Optional[Any] = None,
use_segmm: Optional[bool] = None,
**kwargs,
):
super().__init__()
Expand Down Expand Up @@ -350,6 +373,7 @@ def __init__(

self.in_channels = in_channels
self.out_channels = out_channels
self.use_segmm = use_segmm
self.kwargs = kwargs

self.lins = torch.nn.ModuleDict({
Expand Down Expand Up @@ -377,8 +401,13 @@ def forward(

# Only apply fused kernel for more than 10 types, otherwise default
# back to sequential computation (which is faster for these cases).
if self.use_segmm is None:
use_segmm = len(x_dict) >= 10
else:
use_segmm = self.use_segmm

if (torch_geometric.typing.WITH_GMM and not torch.jit.is_scripting()
and len(x_dict) >= 10):
and use_segmm):
xs, weights, biases = [], [], []
for key, lin in self.lins.items():
if key in x_dict:
Expand Down

0 comments on commit 129deee

Please sign in to comment.