Merge pull request #114 from geometric-intelligence/topology_learning
Topology learning
levtelyatnikov authored Nov 25, 2024
2 parents 6b5c349 + cf84728 commit fb4f433
Showing 32 changed files with 1,022 additions and 85 deletions.
2 changes: 1 addition & 1 deletion __init__.py
@@ -4,9 +4,9 @@
 import topobenchmarkx

 __all__ = [
-    "topobenchmarkx",
     "configs",
     "test",
+    "topobenchmarkx",
 ]

 __version__ = "0.0.1"
2 changes: 2 additions & 0 deletions configs/loss/default.yaml
@@ -1,7 +1,9 @@
_target_: topobenchmarkx.loss.TBXLoss

dataset_loss:
  task: ${dataset.parameters.task}
  loss_type: ${dataset.parameters.loss_type}

modules_losses: # Collect model losses
  feature_encoder: ${oc.select:model.feature_encoder.loss,null}
  backbone: ${oc.select:model.backbone.loss,null}
41 changes: 41 additions & 0 deletions configs/model/graph/gcn_dgm.yaml
@@ -0,0 +1,41 @@
_target_: topobenchmarkx.model.TBXModel

model_name: gcn
model_domain: graph

feature_encoder:
  _target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name}
  encoder_name: DGMStructureFeatureEncoder
  in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}}
  out_channels: 64
  proj_dropout: 0.0
  loss:
    _target_: topobenchmarkx.loss.model.DGMLoss
    loss_weight: 10

backbone:
  _target_: torch_geometric.nn.models.GCN
  in_channels: ${model.feature_encoder.out_channels}
  hidden_channels: ${model.feature_encoder.out_channels}
  num_layers: 1
  dropout: 0.0
  act: relu

backbone_wrapper:
  _target_: topobenchmarkx.nn.wrappers.GNNWrapper
  _partial_: true
  wrapper_name: GNNWrapper
  out_channels: ${model.feature_encoder.out_channels}
  num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}}

readout:
  _target_: topobenchmarkx.nn.readouts.${model.readout.readout_name}
  readout_name: NoReadOut # Use <NoReadOut> if no readout is needed. Options: PropagateSignalDown
  num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider
  hidden_dim: ${model.feature_encoder.out_channels}
  out_channels: ${dataset.parameters.num_classes}
  task_level: ${dataset.parameters.task_level}
  pooling_type: sum

# Compile the model for faster training with PyTorch 2.0
compile: false
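
One way to sanity-check the new config is to compose it with Hydra's compose API. The snippet below is a minimal sketch, not the project's documented workflow: it assumes it is run from the repository root with run.yaml as the primary config, and it overrides `transforms` so the custom `${get_default_transform:...}` resolver in the defaults list is never evaluated; only literal fields are read, so the other custom resolvers stay unresolved.

```python
# Minimal sketch (assumptions noted above): compose the run config with the
# new DGM model selected, overriding transforms so no custom resolver is
# needed at defaults-resolution time.
from hydra import compose, initialize

with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="run",
        overrides=["model=graph/gcn_dgm", "transforms=knn"],
    )

# Only literal fields are accessed, so custom resolvers are not triggered.
print(cfg.model.feature_encoder.encoder_name)  # DGMStructureFeatureEncoder
print(cfg.model.backbone.num_layers)           # 1
```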
4 changes: 2 additions & 2 deletions configs/run.yaml
@@ -5,8 +5,8 @@
 defaults:
   - _self_
   - dataset: graph/cocitation_cora
-  - model: cell/topotune
-  - transforms: ${get_default_transform:${dataset},${model}} #no_transform
+  - model: graph/gcn_dgm
+  - transforms: ${get_default_transform:${dataset},${model}} #tree #${get_default_transform:${dataset},${model}} #no_transform
   - optimizer: default
   - loss: default
   - evaluator: default
@@ -2,7 +2,6 @@ _target_: topobenchmarkx.transforms.data_transform.DataTransform
 transform_name: "InfereKNNConnectivity"
 transform_type: "data manipulation"
 args:
-  k: 5 # Number of nearest neighbors to consider
+  k: 40 # Number of nearest neighbors to consider
   cosine: false # If true, use cosine distance instead of Euclidean distance to find nearest neighbors. (Note: setting this to true currently raises an error.)
   loop: false # If true, the graph will contain self-loops. Note: enabling this together with a simplicial lifting raises an error, because self-edges produce simplices with duplicated nodes.
-
3 changes: 3 additions & 0 deletions configs/transforms/data_manipulations/infere_tree.yaml
@@ -0,0 +1,3 @@
_target_: topobenchmarkx.transforms.data_transform.DataTransform
transform_name: "InferTreeConnectivity"
#split_params: ${dataset.split_params}
2 changes: 2 additions & 0 deletions configs/transforms/knn.yaml
@@ -0,0 +1,2 @@
defaults:
  - /transforms/data_manipulations@knn: infere_knn_connectivity
2 changes: 2 additions & 0 deletions configs/transforms/tree.yaml
@@ -0,0 +1,2 @@
defaults:
  - /transforms/data_manipulations@tree: infere_tree
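
A quick way to confirm the wiring is to load the underlying transform config directly with OmegaConf (a minimal sketch, assuming it runs from the repository root; OmegaConf ships with Hydra):

```python
# Minimal sketch: read the new tree-connectivity transform config as-is.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/transforms/data_manipulations/infere_tree.yaml")
print(cfg.transform_name)  # InferTreeConnectivity
print(cfg["_target_"])     # topobenchmarkx.transforms.data_transform.DataTransform
```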
14 changes: 0 additions & 14 deletions index.html

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "TopoBenchmark"
 dynamic = ["version"]
 authors = [
-    {name = "PyT-Team Authors", email = "tlscabinet@gmail.com"}
+    {name = "Topological Intelligence Team Authors", email = "tlscabinet@gmail.com"}
 ]
 readme = "README.md"
 description = "Topological Deep Learning"
6 changes: 3 additions & 3 deletions topobenchmarkx/__init__.py
@@ -15,14 +15,14 @@

 __all__ = [
     "data",
+    "dataloader",
     "evaluator",
+    "initialize_hydra",
     "loss",
+    "model",
     "nn",
     "transforms",
     "utils",
-    "dataloader",
-    "model",
-    "initialize_hydra",
 ]


2 changes: 1 addition & 1 deletion topobenchmarkx/data/loaders/hypergraph/__init__.py
@@ -31,7 +31,7 @@ def is_loader_class(obj: Any) -> bool:
         return (
             inspect.isclass(obj)
             and not obj.__name__.startswith("_")
-            and "HypergraphDatasetLoader" in obj.__name__
+            and "DatasetLoader" in obj.__name__
         )

     @classmethod
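The relaxed check means any public class whose name contains "DatasetLoader" is now discovered, not only names containing "HypergraphDatasetLoader". A toy illustration (both class names below are hypothetical):

```python
# Toy illustration of the relaxed predicate; class names are hypothetical.
class CitationHypergraphDatasetLoader: ...
class CocitationDatasetLoader: ...

for klass in (CitationHypergraphDatasetLoader, CocitationDatasetLoader):
    # The old check matched only the first name; the new one matches both.
    print(klass.__name__, "DatasetLoader" in klass.__name__)
```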
2 changes: 1 addition & 1 deletion topobenchmarkx/dataloader/__init__.py
@@ -3,4 +3,4 @@
 from .dataload_dataset import DataloadDataset
 from .dataloader import TBXDataloader

-__all__ = ["TBXDataloader", "DataloadDataset"]
+__all__ = ["DataloadDataset", "TBXDataloader"]
2 changes: 1 addition & 1 deletion topobenchmarkx/evaluator/__init__.py
@@ -17,7 +17,7 @@
 from .evaluator import TBXEvaluator  # noqa: E402

 __all__ = [
+    "METRICS",
     "AbstractEvaluator",
     "TBXEvaluator",
-    "METRICS",
 ]
107 changes: 96 additions & 11 deletions topobenchmarkx/loss/__init__.py
@@ -1,17 +1,102 @@
 """This module implements the loss functions for the topobenchmarkx package."""

-from .base import AbstractLoss
-from .loss import TBXLoss
+import importlib
+import inspect
+import sys
+from pathlib import Path
+from typing import Any

-# ... import other readout classes here
-# For example:
-# from topobenchmarkx.loss.other_loss_1 import OtherLoss1
-# from topobenchmarkx.loss.other_loss_2 import OtherLoss2
+
+class LoadManager:
+    """Manages automatic discovery and registration of loss classes."""
+
+    @staticmethod
+    def is_encoder_class(obj: Any) -> bool:
+        """Check if an object is a valid loss class.
+
+        Parameters
+        ----------
+        obj : Any
+            The object to check if it's a valid loss class.
+
+        Returns
+        -------
+        bool
+            True if the object is a valid loss class (a non-private
+            subclass of AbstractLoss, excluding AbstractLoss itself),
+            False otherwise.
+        """
+        try:
+            from .base import AbstractLoss
+
+            return (
+                inspect.isclass(obj)
+                and not obj.__name__.startswith("_")
+                and issubclass(obj, AbstractLoss)
+                and obj is not AbstractLoss
+            )
+        except ImportError:
+            return False
+
+    @classmethod
+    def discover_losses(cls, package_path: str) -> dict[str, type]:
+        """Dynamically discover all loss classes in the package.
+
+        Parameters
+        ----------
+        package_path : str
+            Path to the package's __init__.py file.
+
+        Returns
+        -------
+        dict[str, type]
+            Dictionary mapping loss class names to their corresponding class objects.
+        """
+        losses = {}
+        package_dir = Path(package_path).parent
+
+        # Add the parent directory to sys.path so the module imports below resolve
+        parent_dir = str(package_dir.parent)
+        if parent_dir not in sys.path:
+            sys.path.insert(0, parent_dir)
+
+        # Iterate through all .py files in the directory
+        for file_path in package_dir.glob("*.py"):
+            if file_path.stem == "__init__":
+                continue
+
+            try:
+                # Use importlib to safely import the module
+                module_name = f"{package_dir.stem}.{file_path.stem}"
+                module = importlib.import_module(module_name)
+
+                # Find all loss classes defined in (not merely imported into) the module
+                for name, obj in inspect.getmembers(module):
+                    if (
+                        cls.is_encoder_class(obj)
+                        and obj.__module__ == module.__name__
+                    ):
+                        losses[name] = obj  # noqa: PERF403
+
+            except ImportError as e:
+                print(f"Could not import module {module_name}: {e}")
+
+        return losses
+
+
+# Dynamically create the loss manager and discover losses
+manager = LoadManager()
+LOSSES = manager.discover_losses(__file__)
+LOSSES_list = list(LOSSES.keys())
+
+# Combine manually registered and discovered losses
+all_encoders = {**LOSSES}
+
+# Generate __all__
 __all__ = [
-    "AbstractLoss",
-    "TBXLoss",
-    # "OtherLoss1",
-    # "OtherLoss2",
-    # ... add other loss classes here
+    "LOSSES",
+    "LOSSES_list",
+    *list(all_encoders.keys()),
 ]
+
+# Update locals for direct import
+locals().update(all_encoders)
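
A hedged sketch of how the resulting registry can be used (assuming the package imports cleanly; the exact contents of LOSSES depend on which modules define AbstractLoss subclasses):

```python
# Sketch: inspect the auto-discovered loss registry.
from topobenchmarkx import loss

print(loss.LOSSES_list)  # e.g. ['TBXLoss', ...] -- contents may vary

# Discovered classes are injected into the module namespace via
# locals().update, so registry lookup and attribute access agree:
assert loss.LOSSES["TBXLoss"] is loss.TBXLoss
```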
110 changes: 107 additions & 3 deletions topobenchmarkx/loss/dataset/__init__.py
@@ -1,7 +1,111 @@
-"""Init file for custom loss module."""
+"""This module implements the loss functions for the topobenchmarkx package."""

-from .DatasetLoss import DatasetLoss
+import importlib
+import inspect
+import sys
+from pathlib import Path
+from typing import Any
+
+
+class LoadManager:
+    """Manages automatic discovery and registration of loss classes."""
+
+    @staticmethod
+    def is_encoder_class(obj: Any) -> bool:
+        """Check if an object is a valid loss class.
+
+        Parameters
+        ----------
+        obj : Any
+            The object to check if it's a valid loss class.
+
+        Returns
+        -------
+        bool
+            True if the object is a valid loss class (a non-private
+            subclass of AbstractLoss, excluding AbstractLoss itself),
+            False otherwise.
+        """
+        try:
+            from ..base import AbstractLoss
+
+            return (
+                inspect.isclass(obj)
+                and not obj.__name__.startswith("_")
+                and issubclass(obj, AbstractLoss)
+                and obj is not AbstractLoss
+            )
+        except ImportError:
+            return False
+
+    @classmethod
+    def discover_losses(cls, package_path: str) -> dict[str, type]:
+        """Dynamically discover all loss classes in the package.
+
+        Parameters
+        ----------
+        package_path : str
+            Path to the package's __init__.py file.
+
+        Returns
+        -------
+        dict[str, type]
+            Dictionary mapping loss class names to their corresponding class objects.
+        """
+        losses = {}
+        package_dir = Path(package_path).parent
+
+        # Add the parent directory to sys.path so the module imports below resolve
+        parent_dir = str(package_dir.parent)
+        if parent_dir not in sys.path:
+            sys.path.insert(0, parent_dir)
+
+        # Iterate through all .py files in the directory
+        for file_path in package_dir.glob("*.py"):
+            if file_path.stem == "__init__":
+                continue
+
+            try:
+                # Use importlib to safely import the module
+                module_name = f"{package_dir.stem}.{file_path.stem}"
+                module = importlib.import_module(module_name)
+
+                # Find all loss classes defined in (not merely imported into) the module
+                for name, obj in inspect.getmembers(module):
+                    if (
+                        cls.is_encoder_class(obj)
+                        and obj.__module__ == module.__name__
+                    ):
+                        losses[name] = obj  # noqa: PERF403
+
+            except ImportError as e:
+                print(f"Could not import module {module_name}: {e}")
+
+        return losses
+
+
+# Dynamically create the loss manager and discover losses
+manager = LoadManager()
+LOSSES = manager.discover_losses(__file__)
+LOSSES_list = list(LOSSES.keys())
+
+# Combine manually registered and discovered losses
+all_encoders = {**LOSSES}
+
+# Generate __all__
 __all__ = [
     "DatasetLoss",
+    "LOSSES",
+    "LOSSES_list",
+    *list(all_encoders.keys()),
 ]
+
+# Update locals for direct import
+locals().update(all_encoders)
+
+
+# """Init file for custom loss module."""
+
+# from .DatasetLoss import DatasetLoss
+
+# __all__ = [
+#     "DatasetLoss",
+# ]