Skip to content

Commit

Permalink
Merge pull request #71 from pyt-team/update_docs_api
Browse files Browse the repository at this point in the history
updated docs api
  • Loading branch information
gbg141 authored Jun 4, 2024
2 parents 6c3e23f + 65132c4 commit bdbe647
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 14 deletions.
3 changes: 0 additions & 3 deletions docs/api/nn/backbones/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ Backbones
.. automodule:: topobenchmarkx.nn.backbones.cell.cccn
:members:

.. automodule:: topobenchmarkx.nn.backbones.cell.cin
:members:

.. automodule:: topobenchmarkx.nn.backbones.hypergraph.edgnn
:members:

Expand Down
16 changes: 9 additions & 7 deletions test/nn/test_auto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""
Module for automated testing of neural network modules.
This module contains functions and utilities to automate the testing
of various neural network modules using the NNModuleAutoTest class.
"""
import torch
from .._utils.nn_module_auto_test import NNModuleAutoTest
from topobenchmarkx.nn.backbones.cell.cccn import CCCN
from topobenchmarkx.nn.backbones.cell.cin import CWN
from topobenchmarkx.nn.backbones.hypergraph.edgnn import (
EDGNN,
MLP as edgnn_MLP,
Expand All @@ -11,6 +16,9 @@


def test_auto():
"""
Function to automate the testing of the modules.
"""
num_nodes = 8
d_feat = 12
x = torch.randn(num_nodes, 12)
Expand All @@ -31,12 +39,6 @@ def test_auto():
"forward": (x, edges_1, edges_2),
"assert_shape": (num_nodes, d_feat)
},
#{
# "module" : CWN,
# "init": (d_feat, d_feat_1, d_feat_2, hid_channels, n_layers),
# "forward": (x, x_1, x_2, edges_1, edges_1, edges_1),
# #"assert_shape": (num_nodes, d_feat)
#},
{
"module" : EDGNN,
"init": (d_feat, ),
Expand Down
2 changes: 0 additions & 2 deletions topobenchmarkx/nn/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from .cell import (
CCCN,
CWN,
)
from .hypergraph import EDGNN
from .simplicial import SCCNNCustom

__all__ = [
"CCCN",
"CWN",
"EDGNN",
"SCCNNCustom",
]
2 changes: 0 additions & 2 deletions topobenchmarkx/nn/backbones/cell/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Cell backbones."""

from .cccn import CCCN
from .cin import CWN

__all__ = [
"CCCN",
"CWN",
]

0 comments on commit bdbe647

Please sign in to comment.