Enable pyright (#6415)
rusty1s authored Jan 13, 2023
1 parent 33b7e98 commit 3a3df12
Showing 9 changed files with 78 additions and 41 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/linting.yml
@@ -8,6 +8,25 @@ on: # yamllint disable-line rule:truthy

 jobs:

+  pyright:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Setup packages
+        uses: ./.github/actions/setup
+
+      - name: Install pyright
+        run: |
+          pip install pyright
+
+      - name: Run pyright
+        continue-on-error: true
+        run: |
+          pyright
+
   pylint:
     runs-on: ubuntu-latest

4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -3,8 +3,11 @@ repos:
     rev: v4.4.0
     hooks:
       - id: end-of-file-fixer
+        name: End-of-file fixer
       - id: trailing-whitespace
+        name: Remove trailing whitespaces
       - id: check-yaml
+        name: Check YAML
         exclude: |
           (?x)^(
               conda/pytorch-geometric/meta.yaml|
@@ -15,6 +18,7 @@ repos:
     rev: v1.28.0
     hooks:
       - id: yamllint
+        name: Lint YAML
         args: [-c=.yamllint.yml]

   - repo: https://github.com/regebro/pyroma
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.3.0] - 2023-MM-DD
 ### Added
+- Added basic `pyright` type checker support ([#6415](https://github.com/pyg-team/pytorch_geometric/pull/6415))
 - Added a new external resource for link prediction ([#6396](https://github.com/pyg-team/pytorch_geometric/pull/6396))
 - Added `CaptumExplainer` ([#6383](https://github.com/pyg-team/pytorch_geometric/pull/6383), [#6387](https://github.com/pyg-team/pytorch_geometric/pull/6387))
 - Added support for custom `HeteroData` mini-batch class in remote backends ([#6377](https://github.com/pyg-team/pytorch_geometric/pull/6377))
7 changes: 7 additions & 0 deletions pyrightconfig.json
@@ -0,0 +1,7 @@
+{
+  "ignore": [
+  ],
+  "include": [
+    "torch_geometric/utils/*"
+  ]
+}
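For context: the config above scopes checking to `torch_geometric/utils/*`, and the empty `ignore` list can later hold paths to exclude. As a rough illustration (not part of this commit), individual findings can also be suppressed inline, since pyright honours both `# type: ignore` and bare `# pyright: ignore` comments (optionally with a rule name in brackets); the function below is hypothetical:

    def half(x: int) -> int:
        return x // 2


    # Without the trailing comment, pyright flags the bad argument type below:
    half("not a number")  # pyright: ignore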
5 changes: 4 additions & 1 deletion torch_geometric/utils/assortativity.py
@@ -1,4 +1,5 @@
 import torch
+from torch import Tensor

 from torch_geometric.typing import Adj, SparseTensor
 from torch_geometric.utils import coalesce, degree
@@ -29,8 +30,10 @@ def assortativity(edge_index: Adj) -> float:
         -0.666667640209198
     """
     if isinstance(edge_index, SparseTensor):
-        row, col, _ = edge_index.coo()
+        adj: SparseTensor = edge_index
+        row, col, _ = adj.coo()
     else:
+        assert isinstance(edge_index, Tensor)
         row, col = edge_index

     device = row.device
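A brief aside on the change above (a sketch, not part of the commit): `Adj` is a union of `Tensor` and `SparseTensor`, and the added annotation plus `assert isinstance(...)` make the intended type in each branch explicit, so the checker accepts both the `.coo()` call and the tuple unpacking. The same pattern as a standalone function:

    from typing import Tuple

    from torch import Tensor

    from torch_geometric.typing import Adj, SparseTensor


    def rows_and_cols(edge_index: Adj) -> Tuple[Tensor, Tensor]:
        if isinstance(edge_index, SparseTensor):
            adj: SparseTensor = edge_index  # keep the narrowed sparse type explicit
            row, col, _ = adj.coo()
        else:
            # The assert tells pyright this branch only ever sees a dense Tensor.
            assert isinstance(edge_index, Tensor)
            row, col = edge_index
        return row, col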
16 changes: 10 additions & 6 deletions torch_geometric/utils/loop.py
@@ -126,9 +126,11 @@ def add_self_loops(edge_index, edge_attr=None, fill_value=None,


 def add_self_loops(
-        edge_index: Tensor, edge_attr: OptTensor = None,
-        fill_value: Union[float, Tensor, str] = None,
-        num_nodes: Optional[int] = None) -> Tuple[Tensor, OptTensor]:
+    edge_index: Tensor,
+    edge_attr: OptTensor = None,
+    fill_value: Optional[Union[float, Tensor, str]] = None,
+    num_nodes: Optional[int] = None,
+) -> Tuple[Tensor, OptTensor]:
     r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
     :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
     In case the graph is weighted or has multi-dimensional edge features
@@ -234,9 +236,11 @@ def add_remaining_self_loops(edge_index, edge_attr=None, fill_value=None,


 def add_remaining_self_loops(
-        edge_index: Tensor, edge_attr: OptTensor = None,
-        fill_value: Union[float, Tensor, str] = None,
-        num_nodes: Optional[int] = None) -> Tuple[Tensor, OptTensor]:
+    edge_index: Tensor,
+    edge_attr: OptTensor = None,
+    fill_value: Optional[Union[float, Tensor, str]] = None,
+    num_nodes: Optional[int] = None,
+) -> Tuple[Tensor, OptTensor]:
     r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node
     :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
     In case the graph is weighted or has multi-dimensional edge features
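The signature changes above address pyright's handling of `None` defaults: it does not treat `fill_value: Union[float, Tensor, str] = None` as implicitly optional, so the explicit `Optional[...]` wrapper is required. A minimal sketch (not from the codebase):

    from typing import Optional, Union

    from torch import Tensor


    def bad(fill_value: Union[float, Tensor, str] = None):
        ...  # rejected: the default None is not part of the declared type


    def good(fill_value: Optional[Union[float, Tensor, str]] = None):
        ...  # accepted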
8 changes: 4 additions & 4 deletions torch_geometric/utils/nested.py
Expand Up @@ -32,12 +32,12 @@ def to_nested_tensor(
(default: :obj:`None`)
"""
if ptr is not None:
sizes = ptr[1:] - ptr[:-1]
sizes: List[int] = sizes.tolist()
offsets = ptr[1:] - ptr[:-1]
sizes: List[int] = offsets.tolist()
xs = list(torch.split(x, sizes, dim=0))
elif batch is not None:
sizes = scatter(torch.ones_like(batch), batch, dim_size=batch_size)
sizes: List[int] = sizes.tolist()
offsets = scatter(torch.ones_like(batch), batch, dim_size=batch_size)
sizes: List[int] = offsets.tolist()
xs = list(torch.split(x, sizes, dim=0))
else:
xs = [x]
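Why the rename above helps (a sketch, not part of the commit): pyright treats an explicit annotation such as `sizes: List[int]` as the declared type of that name for the whole scope, so first binding a `Tensor` to the same name is reported as an incompatible assignment. Giving the intermediate tensor its own name keeps each variable at a single type:

    from typing import List

    import torch

    ptr = torch.tensor([0, 2, 5])

    # Flagged pattern: the Tensor assignment conflicts with the declared List[int].
    #     sizes = ptr[1:] - ptr[:-1]
    #     sizes: List[int] = sizes.tolist()

    # Pattern used in the diff above: separate names, consistent types.
    offsets = ptr[1:] - ptr[:-1]
    sizes: List[int] = offsets.tolist()
    print(sizes)  # [2, 3]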
57 changes: 28 additions & 29 deletions torch_geometric/utils/subgraph.py
@@ -3,7 +3,7 @@
 import torch
 from torch import Tensor

-from torch_geometric.typing import PairTensor
+from torch_geometric.typing import OptTensor, PairTensor

 from .mask import index_to_mask
 from .num_nodes import maybe_num_nodes
@@ -40,11 +40,11 @@ def get_num_hops(model: torch.nn.Module) -> int:

 def subgraph(
     subset: Union[Tensor, List[int]],
     edge_index: Tensor,
-    edge_attr: Optional[Tensor] = None,
+    edge_attr: OptTensor = None,
     relabel_nodes: bool = False,
     num_nodes: Optional[int] = None,
     return_edge_mask: bool = False,
-) -> Tuple[Tensor, Tensor]:
+) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]:
     r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
     containing the nodes in :obj:`subset`.
@@ -114,11 +114,11 @@ def subgraph(

 def bipartite_subgraph(
     subset: Union[PairTensor, Tuple[List[int], List[int]]],
     edge_index: Tensor,
-    edge_attr: Optional[Tensor] = None,
+    edge_attr: OptTensor = None,
     relabel_nodes: bool = False,
-    size: Tuple[int, int] = None,
+    size: Optional[Tuple[int, int]] = None,
     return_edge_mask: bool = False,
-) -> Tuple[Tensor, Tensor]:
+) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]:
     r"""Returns the induced subgraph of the bipartite graph
     :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`.
@@ -161,35 +161,34 @@ def bipartite_subgraph(

     device = edge_index.device

-    if isinstance(subset[0], (list, tuple)):
-        subset = (torch.tensor(subset[0], dtype=torch.long, device=device),
-                  torch.tensor(subset[1], dtype=torch.long, device=device))
+    src_subset, dst_subset = subset
+    if not isinstance(src_subset, Tensor):
+        src_subset = torch.tensor(src_subset, dtype=torch.long, device=device)
+    if not isinstance(dst_subset, Tensor):
+        dst_subset = torch.tensor(dst_subset, dtype=torch.long, device=device)

-    if subset[0].dtype == torch.bool or subset[0].dtype == torch.uint8:
-        size = subset[0].size(0), subset[1].size(0)
-    else:
-        if size is None:
-            size = (maybe_num_nodes(edge_index[0]),
-                    maybe_num_nodes(edge_index[1]))
-        subset = (index_to_mask(subset[0], size=size[0]),
-                  index_to_mask(subset[1], size=size[1]))
+    if src_subset.dtype != torch.bool:
+        src_size = int(edge_index[0].max()) + 1 if size is None else size[0]
+        src_subset = index_to_mask(src_subset, size=src_size)
+        dst_size = int(edge_index[1].max()) + 1 if size is None else size[1]
+        dst_subset = index_to_mask(dst_subset, size=dst_size)

-    node_mask = subset
-    edge_mask = node_mask[0][edge_index[0]] & node_mask[1][edge_index[1]]
+    # node_mask = subset
+    edge_mask = src_subset[edge_index[0]] & dst_subset[edge_index[1]]
     edge_index = edge_index[:, edge_mask]
     edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

     if relabel_nodes:
-        node_idx_i = torch.zeros(node_mask[0].size(0), dtype=torch.long,
-                                 device=device)
-        node_idx_j = torch.zeros(node_mask[1].size(0), dtype=torch.long,
-                                 device=device)
-        node_idx_i[node_mask[0]] = torch.arange(node_mask[0].sum().item(),
-                                                device=device)
-        node_idx_j[node_mask[1]] = torch.arange(node_mask[1].sum().item(),
-                                                device=device)
-        edge_index = torch.stack(
-            [node_idx_i[edge_index[0]], node_idx_j[edge_index[1]]])
+        node_idx_i = edge_index.new_zeros(src_subset.size(0))
+        node_idx_j = edge_index.new_zeros(dst_subset.size(0))
+        node_idx_i[src_subset] = torch.arange(int(src_subset.sum()),
+                                              device=node_idx_i.device)
+        node_idx_j[dst_subset] = torch.arange(int(dst_subset.sum()),
+                                              device=node_idx_j.device)
+        edge_index = torch.stack([
+            node_idx_i[edge_index[0]],
+            node_idx_j[edge_index[1]],
+        ], dim=0)

     if return_edge_mask:
         return edge_index, edge_attr, edge_mask
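A short note on the `bipartite_subgraph` rewrite above (an illustration, not part of the commit): instead of repeatedly re-assigning the `subset` parameter, whose element types shift from lists to index tensors to boolean masks, the new code unpacks it into two locals and converts each one explicitly, which is easier to read and easier for a type checker to follow. A minimal sketch of that unpack-and-convert step, with a hypothetical alias for the accepted pair types:

    from typing import List, Tuple, Union

    import torch
    from torch import Tensor

    SubsetPair = Union[Tuple[Tensor, Tensor], Tuple[List[int], List[int]]]


    def to_index_tensors(subset: SubsetPair) -> Tuple[Tensor, Tensor]:
        src, dst = subset
        if not isinstance(src, Tensor):
            src = torch.tensor(src, dtype=torch.long)
        if not isinstance(dst, Tensor):
            dst = torch.tensor(dst, dtype=torch.long)
        return src, dst


    src, dst = to_index_tensors(([0, 2], [1, 3]))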
2 changes: 1 addition & 1 deletion torch_geometric/utils/to_dense_adj.py
@@ -74,7 +74,7 @@ def to_dense_adj(
     idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

     if max_num_nodes is None:
-        max_num_nodes = num_nodes.max().item()
+        max_num_nodes = int(num_nodes.max())

     elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
           or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
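On the one-line change above: `Tensor.item()` is annotated in the torch stubs as returning a general Python number (assuming the stubs of that era, a union of `int`, `float`, and `bool`), so pyright cannot treat the result as an `int`; wrapping the tensor in `int(...)` yields an unambiguous integer both at runtime and for the checker. A small sketch:

    import torch

    num_nodes = torch.tensor([3, 5, 4])

    a = num_nodes.max().item()  # typed as a general number, not necessarily int
    b = int(num_nodes.max())    # plain Python int for the checker and at runtime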
