-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #114 from geometric-intelligence/topology_learning
Topology learning
- Loading branch information
Showing
32 changed files
with
1,022 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,9 @@ | |
import topobenchmarkx | ||
|
||
__all__ = [ | ||
"topobenchmarkx", | ||
"configs", | ||
"test", | ||
"topobenchmarkx", | ||
] | ||
|
||
__version__ = "0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
_target_: topobenchmarkx.model.TBXModel | ||
|
||
model_name: gcn | ||
model_domain: graph | ||
|
||
feature_encoder: | ||
_target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name} | ||
encoder_name: DGMStructureFeatureEncoder | ||
in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}} | ||
out_channels: 64 | ||
proj_dropout: 0.0 | ||
loss: | ||
_target_: topobenchmarkx.loss.model.DGMLoss | ||
loss_weight: 10 | ||
|
||
backbone: | ||
_target_: torch_geometric.nn.models.GCN | ||
in_channels: ${model.feature_encoder.out_channels} | ||
hidden_channels: ${model.feature_encoder.out_channels} | ||
num_layers: 1 | ||
dropout: 0.0 | ||
act: relu | ||
|
||
backbone_wrapper: | ||
_target_: topobenchmarkx.nn.wrappers.GNNWrapper | ||
_partial_: true | ||
wrapper_name: GNNWrapper | ||
out_channels: ${model.feature_encoder.out_channels} | ||
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} | ||
|
||
readout: | ||
_target_: topobenchmarkx.nn.readouts.${model.readout.readout_name} | ||
readout_name: NoReadOut # Use <NoReadOut> in case readout is not needed Options: PropagateSignalDown | ||
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider | ||
hidden_dim: ${model.feature_encoder.out_channels} | ||
out_channels: ${dataset.parameters.num_classes} | ||
task_level: ${dataset.parameters.task_level} | ||
pooling_type: sum | ||
|
||
# compile model for faster training with pytorch 2.0 | ||
compile: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
_target_: topobenchmarkx.transforms.data_transform.DataTransform | ||
transform_name: "InferTreeConnectivity" | ||
#split_params: ${dataset.split_params} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
defaults: | ||
- /transforms/data_manipulations@knn: infere_knn_connectivity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
defaults: | ||
- /transforms/data_manipulations@tree: infere_tree |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,102 @@ | ||
"""This module implements the loss functions for the topobenchmarkx package.""" | ||
|
||
from .base import AbstractLoss | ||
from .loss import TBXLoss | ||
import importlib | ||
import inspect | ||
import sys | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
# ... import other readout classes here | ||
# For example: | ||
# from topobenchmarkx.loss.other_loss_1 import OtherLoss1 | ||
# from topobenchmarkx.loss.other_loss_2 import OtherLoss2 | ||
|
||
class LoadManager: | ||
"""Manages automatic discovery and registration of loss classes.""" | ||
|
||
@staticmethod | ||
def is_encoder_class(obj: Any) -> bool: | ||
"""Check if an object is a valid loss class. | ||
Parameters | ||
---------- | ||
obj : Any | ||
The object to check if it's a valid loss class. | ||
Returns | ||
------- | ||
bool | ||
True if the object is a valid loss class (non-private class | ||
with 'FeatureEncoder' in name), False otherwise. | ||
""" | ||
try: | ||
from .base import AbstractLoss | ||
|
||
return ( | ||
inspect.isclass(obj) | ||
and not obj.__name__.startswith("_") | ||
and issubclass(obj, AbstractLoss) | ||
and obj is not AbstractLoss | ||
) | ||
except ImportError: | ||
return False | ||
|
||
@classmethod | ||
def discover_losses(cls, package_path: str) -> dict[str, type]: | ||
"""Dynamically discover all loss classes in the package. | ||
Parameters | ||
---------- | ||
package_path : str | ||
Path to the package's __init__.py file. | ||
Returns | ||
------- | ||
Dict[str, Type] | ||
Dictionary mapping loss class names to their corresponding class objects. | ||
""" | ||
losses = {} | ||
package_dir = Path(package_path).parent | ||
|
||
# Add parent directory to sys.path to ensure imports work | ||
parent_dir = str(package_dir.parent) | ||
if parent_dir not in sys.path: | ||
sys.path.insert(0, parent_dir) | ||
|
||
# Iterate through all .py files in the directory | ||
for file_path in package_dir.glob("*.py"): | ||
if file_path.stem == "__init__": | ||
continue | ||
|
||
try: | ||
# Use importlib to safely import the module | ||
module_name = f"{package_dir.stem}.{file_path.stem}" | ||
module = importlib.import_module(module_name) | ||
|
||
# Find all loss classes in the module | ||
for name, obj in inspect.getmembers(module): | ||
if ( | ||
cls.is_encoder_class(obj) | ||
and obj.__module__ == module.__name__ | ||
): | ||
losses[name] = obj # noqa: PERF403 | ||
|
||
except ImportError as e: | ||
print(f"Could not import module {module_name}: {e}") | ||
|
||
return losses | ||
|
||
|
||
# Dynamically create the loss manager and discover losses | ||
manager = LoadManager() | ||
LOSSES = manager.discover_losses(__file__) | ||
LOSSES_list = list(LOSSES.keys()) | ||
|
||
# Combine manual and discovered losses | ||
all_encoders = {**LOSSES} | ||
|
||
# Generate __all__ | ||
__all__ = [ | ||
"AbstractLoss", | ||
"TBXLoss", | ||
# "OtherLoss1", | ||
# "OtherLoss2", | ||
# ... add other loss classes here | ||
"LOSSES", | ||
"LOSSES_list", | ||
*list(all_encoders.keys()), | ||
] | ||
|
||
# Update locals for direct import | ||
locals().update(all_encoders) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,111 @@ | ||
"""Init file for custom loss module.""" | ||
"""This module implements the loss functions for the topobenchmarkx package.""" | ||
|
||
from .DatasetLoss import DatasetLoss | ||
import importlib | ||
import inspect | ||
import sys | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
|
||
class LoadManager: | ||
"""Manages automatic discovery and registration of loss classes.""" | ||
|
||
@staticmethod | ||
def is_encoder_class(obj: Any) -> bool: | ||
"""Check if an object is a valid loss class. | ||
Parameters | ||
---------- | ||
obj : Any | ||
The object to check if it's a valid loss class. | ||
Returns | ||
------- | ||
bool | ||
True if the object is a valid loss class (non-private class | ||
with 'FeatureEncoder' in name), False otherwise. | ||
""" | ||
try: | ||
from ..base import AbstractLoss | ||
|
||
return ( | ||
inspect.isclass(obj) | ||
and not obj.__name__.startswith("_") | ||
and issubclass(obj, AbstractLoss) | ||
and obj is not AbstractLoss | ||
) | ||
except ImportError: | ||
return False | ||
|
||
@classmethod | ||
def discover_losses(cls, package_path: str) -> dict[str, type]: | ||
"""Dynamically discover all loss classes in the package. | ||
Parameters | ||
---------- | ||
package_path : str | ||
Path to the package's __init__.py file. | ||
Returns | ||
------- | ||
Dict[str, Type] | ||
Dictionary mapping loss class names to their corresponding class objects. | ||
""" | ||
losses = {} | ||
package_dir = Path(package_path).parent | ||
|
||
# Add parent directory to sys.path to ensure imports work | ||
parent_dir = str(package_dir.parent) | ||
if parent_dir not in sys.path: | ||
sys.path.insert(0, parent_dir) | ||
|
||
# Iterate through all .py files in the directory | ||
for file_path in package_dir.glob("*.py"): | ||
if file_path.stem == "__init__": | ||
continue | ||
|
||
try: | ||
# Use importlib to safely import the module | ||
module_name = f"{package_dir.stem}.{file_path.stem}" | ||
module = importlib.import_module(module_name) | ||
|
||
# Find all loss classes in the module | ||
for name, obj in inspect.getmembers(module): | ||
if ( | ||
cls.is_encoder_class(obj) | ||
and obj.__module__ == module.__name__ | ||
): | ||
losses[name] = obj # noqa: PERF403 | ||
|
||
except ImportError as e: | ||
print(f"Could not import module {module_name}: {e}") | ||
|
||
return losses | ||
|
||
|
||
# Dynamically create the loss manager and discover losses | ||
manager = LoadManager() | ||
LOSSES = manager.discover_losses(__file__) | ||
LOSSES_list = list(LOSSES.keys()) | ||
|
||
# Combine manual and discovered losses | ||
all_encoders = {**LOSSES} | ||
|
||
# Generate __all__ | ||
__all__ = [ | ||
"DatasetLoss", | ||
"LOSSES", | ||
"LOSSES_list", | ||
*list(all_encoders.keys()), | ||
] | ||
|
||
# Update locals for direct import | ||
locals().update(all_encoders) | ||
|
||
|
||
# """Init file for custom loss module.""" | ||
|
||
# from .DatasetLoss import DatasetLoss | ||
|
||
# __all__ = [ | ||
# "DatasetLoss", | ||
# ] |
Oops, something went wrong.