Skip to content

Commit

Permalink
Merge pull request #70 from pyt-team/ninamiolane-apiiii
Browse files Browse the repository at this point in the history
Fix API errors by fixing documentation errors in the codebase
  • Loading branch information
levtelyatnikov authored Jun 4, 2024
2 parents b8d866f + e0efde4 commit 31f10a0
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 275 deletions.
2 changes: 1 addition & 1 deletion docs/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ The API reference gives an overview of `TopoBenchmarkX`, which consists of sever
- `utils` implements utilities to handle the training process.

.. toctree::
:maxdepth: 3
:maxdepth: 2
:caption: Packages & Modules

data/index
Expand Down
2 changes: 1 addition & 1 deletion docs/api/nn/encoders/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Encoders
.. automodule:: topobenchmarkx.nn.encoders.base
:members:

.. automodule:: topobenchmarkx.nn.encoders.all_cell_encoders
.. automodule:: topobenchmarkx.nn.encoders.all_cell_encoder
:members:
4 changes: 2 additions & 2 deletions docs/api/transforms/data_manipulations/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Data Manipulations
.. automodule:: topobenchmarkx.transforms.data_manipulations.calculate_simplicial_curvature
:members:

.. automodule:: topobenchmarkx.transforms.data_manipulations.equal_gauss_features
.. automodule:: topobenchmarkx.transforms.data_manipulations.equal_gaus_features
:members:

.. automodule:: topobenchmarkx.transforms.data_manipulations.identity_transform
Expand All @@ -17,7 +17,7 @@ Data Manipulations
.. automodule:: topobenchmarkx.transforms.data_manipulations.infere_radius_connectivity
:members:

.. automodule:: topobenchmarkx.transforms.data_manipulations.keep_only_connected_components
.. automodule:: topobenchmarkx.transforms.data_manipulations.keep_only_connected_component
:members:

.. automodule:: topobenchmarkx.transforms.data_manipulations.keep_selected_data_fields
Expand Down
6 changes: 6 additions & 0 deletions docs/api/transforms/data_transform/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
**************
Data Transform
**************

.. automodule:: topobenchmarkx.transforms.data_transform
:members:
6 changes: 0 additions & 6 deletions docs/api/transforms/data_transforms/index.rst

This file was deleted.

4 changes: 2 additions & 2 deletions docs/api/transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The `transforms` module of `TopoBenchmarkX`` consists of several submodules:
1. :doc:`data_manipulations <data_manipulations/index>`
2. :doc:`feature_liftings <feature_liftings/index>`
3. :doc:`liftings <liftings/index>`
4. :doc:`data_transforms <data_transforms/index>`
4. :doc:`data_transform <data_transform/index>`

.. toctree::
:maxdepth: 2
Expand All @@ -16,4 +16,4 @@ The `transforms` module of `TopoBenchmarkX`` consists of several submodules:
data_manipulations/index
feature_liftings/index
liftings/index
data_transforms/index
data_transform/index
3 changes: 0 additions & 3 deletions docs/api/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ This module implements implements additional utilities to handle the training pr
.. automodule:: topobenchmarkx.utils.config_resolvers
:members:

.. automodule:: topobenchmarkx.utils.dataset_statistics
:members:

.. automodule:: topobenchmarkx.utils.instantiators
:members:

Expand Down
32 changes: 18 additions & 14 deletions topobenchmarkx/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ class GraphLoader(AbstractLoader):
Parameters
----------
parameters : DictConfig
Configuration parameters. The parameters must contain the following keys:
- data_dir (str): The directory where the dataset is stored.
- data_name (str): The name of the dataset.
- data_type (str): The type of the dataset.
- split_type (str): The type of split to be used. It can be "fixed", "random", or "k-fold".
If split_type is "random", the parameters must also contain the following keys:
- data_seed (int): The seed for the split.
- data_split_dir (str): The directory where the split is stored.
- train_prop (float): The proportion of the training set.
If split_type is "k-fold", the parameters must also contain the following keys:
- data_split_dir (str): The directory where the split is stored.
- k (int): The number of folds.
- data_seed (int): The seed for the split.
The parameters can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/dataset/config.yaml')`.
Configuration parameters.
**kwargs : dict
Additional keyword arguments.
Notes
-----
The parameters must contain the following keys:
- data_dir (str): The directory where the dataset is stored.
- data_name (str): The name of the dataset.
- data_type (str): The type of the dataset.
- split_type (str): The type of split to be used. It can be "fixed", "random", or "k-fold".
If split_type is "random", the parameters must also contain the following keys:
- data_seed (int): The seed for the split.
- data_split_dir (str): The directory where the split is stored.
- train_prop (float): The proportion of the training set.
If split_type is "k-fold", the parameters must also contain the following keys:
- data_split_dir (str): The directory where the split is stored.
- k (int): The number of folds.
- data_seed (int): The seed for the split.
The parameters can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/dataset/config.yaml')`.
"""

def __init__(self, parameters: DictConfig, **kwargs):
Expand Down
14 changes: 7 additions & 7 deletions topobenchmarkx/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class TBXEvaluator(AbstractEvaluator):
The task type. It can be either "classification" or "regression".
**kwargs : dict
Additional arguments for the class. The arguments depend on the task.
In "classification" scenario, the following arguments are expected:
- num_classes (int): The number of classes.
- classification_metrics (list[str]): A list of classification metrics to be computed.
In "regression" scenario, the following arguments are expected:
- regression_metrics (list[str]): A list of regression metrics to be computed.
In "classification" scenario, the following arguments are expected:
- num_classes (int): The number of classes.
- classification_metrics (list[str]): A list of classification metrics to be computed.
In "regression" scenario, the following arguments are expected:
- regression_metrics (list[str]): A list of regression metrics to be computed.
"""

def __init__(self, task, **kwargs):
Expand Down Expand Up @@ -68,9 +68,9 @@ def update(self, model_out: dict):
model_out : dict
The model output. It should contain the following keys:
- logits : torch.Tensor
The model predictions.
The model predictions.
- labels : torch.Tensor
The ground truth labels.
The ground truth labels.
Raises
------
Expand Down
14 changes: 14 additions & 0 deletions topobenchmarkx/nn/backbones/cell/cin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ class CWN(torch.nn.Module):
Dimension of hidden features.
n_layers : int
Number of CWN layers.
References
----------
.. [1] Bodnar, et al.
Weisfeiler and Lehman go cellular: CW networks.
NeurIPS 2021.
https://arxiv.org/abs/2106.12575
"""

def __init__(
Expand Down Expand Up @@ -135,6 +142,13 @@ class CWNLayer(nn.Module):
A module that updates the aggregated representations of r-cells (default: None).
eps : float, optional
A learnable parameter that scales the input features before the first convolutional layer (default: 0.01).
References
----------
.. [1] Bodnar, et al.
Weisfeiler and Lehman go cellular: CW networks.
NeurIPS 2021.
https://arxiv.org/abs/2106.12575
"""

def __init__(
Expand Down
Loading

0 comments on commit 31f10a0

Please sign in to comment.