Skip to content

Commit

Permalink
Extend nightly tests to include different PyTorch versions (#7073)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 29, 2023
1 parent 306e790 commit 692a8ce
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ runs:
shell: bash

- name: Install extension packages
if: ${{ inputs.full_install == 'true' }}
if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' }}
run: |
pip install torchvision==${{ inputs.torchvision-version }} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html
shell: bash

- name: Install pyg-lib # pyg-lib is currently only available on Linux.
if: ${{ inputs.full_install == 'true' && runner.os == 'Linux' }}
if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' && runner.os == 'Linux' }}
run: |
pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html
shell: bash
12 changes: 11 additions & 1 deletion .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.8', '3.10']
torch-version: [1.13.0, 2.0.0, nightly]
include:
- torch-version: 1.13.0
torchvision-version: 0.14.0
- torch-version: 2.0.0
torchvision-version: 0.15.0
- torch-version: nightly
torchvision-version: nightly

steps:
- name: Checkout repository
Expand All @@ -25,6 +33,8 @@ jobs:
uses: ./.github/actions/setup
with:
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
torchvision-version: ${{ matrix.torchvision-version }}

- name: Install graphviz
if: ${{ runner.os == 'Linux' }}
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/latest_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ jobs:
uses: ./.github/actions/setup
with:
torch-version: nightly
full_install: false

- name: Install main package
if: steps.changed-files-specific.outputs.only_changed != 'true'
Expand Down
3 changes: 2 additions & 1 deletion test/nn/models/test_graph_unet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch

from torch_geometric.nn import GraphUNet
from torch_geometric.testing import is_full_test
from torch_geometric.testing import is_full_test, onlyLinux


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_graph_unet():
model = GraphUNet(16, 32, 8, depth=3)
out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])'
Expand Down
2 changes: 1 addition & 1 deletion test/nn/norm/test_diff_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_diff_group_norm():

if is_full_test():
jit = torch.jit.script(norm)
assert torch.alllclose(jit(x), out)
assert torch.allclose(jit(x), out)


def test_group_distance_ratio():
Expand Down
3 changes: 2 additions & 1 deletion test/nn/pool/test_asap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import torch

from torch_geometric.nn import ASAPooling, GCNConv, GraphConv
from torch_geometric.testing import is_full_test, onlyFullTest
from torch_geometric.testing import is_full_test, onlyFullTest, onlyLinux


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_asap():
in_channels = 16
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
Expand Down
3 changes: 2 additions & 1 deletion test/nn/pool/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from torch import Tensor

from torch_geometric.nn import radius_graph
from torch_geometric.testing import onlyFullTest
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('torch_cluster')
def test_radius_graph_jit():
class Net(torch.nn.Module):
def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
Expand Down
2 changes: 2 additions & 0 deletions test/transforms/test_add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from torch_geometric.data import Data
from torch_geometric.testing import onlyLinux
from torch_geometric.transforms import (
AddLaplacianEigenvectorPE,
AddRandomWalkPE,
Expand Down Expand Up @@ -52,6 +53,7 @@ def test_add_laplacian_eigenvector_pe():
assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_add_random_walk_pe():
x = torch.randn(6, 4)
edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
Expand Down
2 changes: 2 additions & 0 deletions test/transforms/test_two_hop.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from torch_geometric.data import Data
from torch_geometric.testing import onlyLinux
from torch_geometric.transforms import TwoHop


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_two_hop():
transform = TwoHop()
assert str(transform) == 'TwoHop()'
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/nn/pool/asap.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def forward(
S = S.index_select(1, perm).to_sparse_csr()
A = S.t().to_sparse_csr() @ (A @ S)

edge_index, edge_weight = to_edge_index(A)
if edge_weight is None:
edge_index, _ = to_edge_index(A)
else:
edge_index, edge_weight = to_edge_index(A)

if self.add_self_loops:
edge_index, edge_weight = add_remaining_self_loops(
Expand Down

0 comments on commit 692a8ce

Please sign in to comment.