
Commit

Merge pull request #110 from geometric-intelligence/feat-mantra
Mantra dataset
levtelyatnikov authored Dec 24, 2024
2 parents 04f9665 + 45a9bb3 commit 58083bd
Showing 24 changed files with 757 additions and 40 deletions.
40 changes: 40 additions & 0 deletions configs/dataset/simplicial/mantra_betti_numbers.yaml
@@ -0,0 +1,40 @@
loader:
_target_: topobenchmark.data.loaders.MantraSimplicialDatasetLoader
parameters:
data_domain: simplicial
data_type: topological
data_name: MANTRA
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
manifold_dim: 3
version: "v0.0.5"
task_variable: "betti_numbers" # Options: ['name', 'genus', 'orientable']; to use 'torsion_coefficients' or 'betti_numbers', first fix the multilabel/multiclass issue
model_domain: ${model.model_domain}

# Data definition
parameters:
# In the case of higher-order datasets we have multiple feature dimensions
num_features: [1,1,1]
#num_classes: 2 # Number of classes depends on the task_variable

# Dataset parameters
# task: classification # TODO: adapt pipeline to support multilabel classification
# loss_type: cross_entropy # TODO: adapt pipeline to support multilabel classification
# monitor_metric: accuracy # TODO: adapt pipeline to support multilabel classification
task_level: graph
data_seed: 0

#splits
split_params:
learning_setting: inductive
data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_seed: 0
split_type: random # either "k-fold" or "random" strategies
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
batch_size: 5
num_workers: 0
pin_memory: False
persistent_workers: False
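The config above is consumed through Hydra. A minimal sketch of composing it and instantiating the MANTRA loader, mirroring the pattern in `test/data/load/test_datasetloaders.py` further down in this PR (the `config_path`, `version_base`, and override strings are assumptions about a local checkout, not a fixed API):

```python
# Sketch only: compose the run config with the MANTRA dataset selected and
# instantiate the loader via its _target_, as the updated tests do.
import hydra

with hydra.initialize(config_path="configs", version_base="1.3"):
    cfg = hydra.compose(
        config_name="run.yaml",
        overrides=["dataset=simplicial/mantra_betti_numbers", "model=graph/gat"],
        return_hydra_config=True,
    )
    loader = hydra.utils.instantiate(cfg.dataset.loader)  # MantraSimplicialDatasetLoader
    dataset, data_dir = loader.load()  # returns (dataset, data_dir), as in the tests below
```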
40 changes: 40 additions & 0 deletions configs/dataset/simplicial/mantra_genus.yaml
@@ -0,0 +1,40 @@
loader:
_target_: topobenchmark.data.loaders.MantraSimplicialDatasetLoader
parameters:
data_domain: simplicial
data_type: topological
data_name: MANTRA
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
manifold_dim: 2
version: "v0.0.5"
task_variable: "genus" # Options: ['name', 'genus', 'orientable']; to use 'torsion_coefficients' or 'betti_numbers', first fix the multilabel/multiclass issue
model_domain: ${model.model_domain}

# Data definition
parameters:
# In the case of higher-order datasets we have multiple feature dimensions
num_features: [1,1,1]
num_classes: 8 # Number of classes depends on the task_variable

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph
data_seed: 0

#splits
split_params:
learning_setting: inductive
data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_seed: 0
split_type: random # either "k-fold" or "random" strategies
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
batch_size: 5
num_workers: 0
pin_memory: False
persistent_workers: False
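The `split_params` block selects either a `random` split (driven by `train_prop` and `data_seed`) or `k`-fold cross-validation. As a hedged illustration of what those fields control (not TopoBenchmark's actual splitting code), a seeded random split could look like this:

```python
# Illustration of the random-split semantics implied by split_params;
# the real implementation lives inside topobenchmark and may differ.
import numpy as np

def random_split(num_samples: int, train_prop: float = 0.5, data_seed: int = 0):
    """Return train/val/test index arrays: train gets `train_prop` of the data,
    the remainder is shared equally between validation and test."""
    rng = np.random.default_rng(data_seed)
    idx = rng.permutation(num_samples)
    n_train = int(train_prop * num_samples)
    n_val = (num_samples - n_train) // 2
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

train_idx, val_idx, test_idx = random_split(1000, train_prop=0.5, data_seed=0)
```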
40 changes: 40 additions & 0 deletions configs/dataset/simplicial/mantra_name.yaml
@@ -0,0 +1,40 @@
loader:
_target_: topobenchmark.data.loaders.MantraSimplicialDatasetLoader
parameters:
data_domain: simplicial
data_type: topological
data_name: MANTRA
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
manifold_dim: 2
version: "v0.0.5"
task_variable: "name" # Options: ['name', 'genus', 'orientable']; to use 'torsion_coefficients' or 'betti_numbers', first fix the multilabel/multiclass issue
model_domain: ${model.model_domain}

# Data definition
parameters:
# In the case of higher-order datasets we have multiple feature dimensions
num_features: [1,1,1]
num_classes: 8 # Number of classes depends on the task_variable

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph
data_seed: 0

#splits
split_params:
learning_setting: inductive
data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_seed: 0
split_type: random # either "k-fold" or "random" strategies
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
batch_size: 5
num_workers: 0
pin_memory: False
persistent_workers: False
40 changes: 40 additions & 0 deletions configs/dataset/simplicial/mantra_orientation.yaml
@@ -0,0 +1,40 @@
loader:
_target_: topobenchmark.data.loaders.MantraSimplicialDatasetLoader
parameters:
data_domain: simplicial
data_type: topological
data_name: MANTRA
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
manifold_dim: 2
version: "v0.0.5"
task_variable: "orientable" # Options: ['name', 'genus', 'orientable']; to use 'torsion_coefficients' or 'betti_numbers', first fix the multilabel/multiclass issue
model_domain: ${model.model_domain}

# Data definition
parameters:
# In the case of higher-order datasets we have multiple feature dimensions
num_features: [1,1,1]
num_classes: 2 # Number of classes depends on the task_variable

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph
data_seed: 0

#splits
split_params:
learning_setting: inductive
data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_seed: 0
split_type: random # either "k-fold" or "random" strategies
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
batch_size: 5
num_workers: 0
pin_memory: False
persistent_workers: False
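The `dataloader_params` above are standard PyTorch dataloader arguments. A sketch of how they might be forwarded to a PyTorch Geometric `DataLoader` (assuming `dataset` comes from the loader shown earlier; this is illustrative, not the repository's actual wiring):

```python
# Sketch: forward the dataloader_params from the config to a PyG DataLoader.
from torch_geometric.loader import DataLoader

dataloader_params = dict(
    batch_size=5,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
)
train_loader = DataLoader(dataset, shuffle=True, **dataloader_params)  # `dataset` from the loader above
```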
2 changes: 1 addition & 1 deletion configs/evaluator/classification.yaml
@@ -4,4 +4,4 @@ task_level: ${dataset.parameters.task_level}
num_classes: ${dataset.parameters.num_classes}

# Metrics
metrics: [accuracy, precision, recall, auroc] #Available options: accuracy, auroc, precision, recall
metrics: [accuracy, precision, recall, auroc] # Available options: accuracy, auroc, precision, recall
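A hedged sketch of how the metric names listed here could map onto torchmetrics classes for a multiclass, graph-level task such as MANTRA orientability (illustrative only; the actual evaluator wiring lives in topobenchmark):

```python
# Illustrative metric collection; not the repository's evaluator implementation.
from torchmetrics import MetricCollection
from torchmetrics.classification import AUROC, Accuracy, Precision, Recall

num_classes = 2  # e.g. the MANTRA orientability task
metrics = MetricCollection({
    "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
    "precision": Precision(task="multiclass", num_classes=num_classes, average="macro"),
    "recall": Recall(task="multiclass", num_classes=num_classes, average="macro"),
    "auroc": AUROC(task="multiclass", num_classes=num_classes),
})
```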
6 changes: 3 additions & 3 deletions configs/run.yaml
@@ -4,9 +4,9 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- dataset: graph/ZINC
- model: cell/topotune
- transforms: ${get_default_transform:${dataset},${model}} #tree #${get_default_transform:${dataset},${model}} #no_transform
- dataset: simplicial/mantra_orientation
- model: simplicial/scn
- transforms: ${get_default_transform:${dataset},${model}} #no_transform
- optimizer: default
- loss: default
- evaluator: default
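The `transforms` default is produced by the custom `get_default_transform` resolver, whose behaviour is exercised in `test/utils/test_config_resolvers.py` further down. A generic, simplified illustration of the resolver mechanism (not the repository's actual implementation, which also accounts for the model domain):

```python
# Generic OmegaConf resolver illustration; topobenchmark registers its own
# get_default_transform with richer logic than this simplified stand-in.
from omegaconf import OmegaConf

def get_default_transform(dataset: str, model: str) -> str:
    data_name = dataset.split("/")[-1]
    return f"dataset_defaults/{data_name}"

OmegaConf.register_new_resolver("get_default_transform", get_default_transform)

cfg = OmegaConf.create({
    "dataset": "graph/ZINC",
    "model": "cell/can",
    "transforms": "${get_default_transform:${dataset},${model}}",
})
print(cfg.transforms)  # -> dataset_defaults/ZINC, matching test_get_default_transform
```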
2 changes: 2 additions & 0 deletions configs/transforms/liftings/cell2hypergraph_default.yaml
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
2 changes: 2 additions & 0 deletions configs/transforms/liftings/cell2simplicial_default.yaml
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
2 changes: 2 additions & 0 deletions configs/transforms/liftings/hypergraph2cell_default.yaml
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
2 changes: 2 additions & 0 deletions (file path not shown in this view)
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
2 changes: 2 additions & 0 deletions configs/transforms/liftings/simplicial2graph_default.yaml
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
17 changes: 13 additions & 4 deletions test/data/load/test_datasetloaders.py
@@ -36,10 +36,15 @@ def _gather_config_files(self, base_dir: Path) -> List[str]:
"""
config_files = []
config_base_dir = base_dir / "configs/dataset"
exclude_datasets = {# Below are the datasets whose default transforms are manually overridden with no_transform,
# Below are the datasets whose default transforms we manually override with no_transform,
exclude_datasets = {"karate_club.yaml",
# Below are the datasets whose default transforms we manually override with no_transform,
# due to the lack of a default transform for domain2domain
"REDDIT-BINARY.yaml", "IMDB-MULTI.yaml", "IMDB-BINARY.yaml", #"ZINC.yaml"
}

# Below are the datasets that take quite some time to load and process
self.long_running_datasets = {"mantra_name.yaml", "mantra_orientation.yaml", "mantra_genus.yaml", "mantra_betti_numbers.yaml",}


for dir_path in config_base_dir.iterdir():
@@ -75,12 +80,16 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]:
parameters = hydra.compose(
config_name="run.yaml",
overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"],
return_hydra_config=True

return_hydra_config=True,
)
dataset_loader = hydra.utils.instantiate(parameters.dataset.loader)
print(repr(dataset_loader))
return dataset_loader.load()

if config_file in self.long_running_datasets:
dataset, data_dir = dataset_loader.load(slice=100)
else:
dataset, data_dir = dataset_loader.load()
return dataset, data_dir

def test_dataset_loading_states(self):
"""Test different states and scenarios during dataset loading."""
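For the MANTRA configs the test now calls `load(slice=100)` so only a small portion of the dataset is processed. A hedged sketch of a loader honouring such an optional `slice` argument (the parameter name comes from the test above; the class and helper here are hypothetical):

```python
# Hypothetical loader shape; MantraSimplicialDatasetLoader's real load() lives
# in topobenchmark.data.loaders and may differ in detail.
from typing import Optional, Tuple

class SketchDatasetLoader:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir

    def load(self, slice: Optional[int] = None) -> Tuple[list, str]:
        dataset = self._build_full_dataset()
        if slice is not None:
            dataset = dataset[:slice]  # keep only the first `slice` samples
        return dataset, self.data_dir

    def _build_full_dataset(self) -> list:
        # Placeholder: the real loader downloads MANTRA and builds simplicial samples.
        return list(range(1000))
```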
10 changes: 10 additions & 0 deletions test/utils/test_config_resolvers.py
@@ -47,6 +47,7 @@ def test_get_default_transform(self):
out = get_default_transform("graph/ZINC", "cell/can")
assert out == "dataset_defaults/ZINC"


def test_get_required_lifting(self):
"""Test get_required_lifting."""
out = get_required_lifting("graph", "graph/gat")
@@ -106,6 +107,15 @@ def test_infer_in_channels(self):
in_channels = infer_in_channels(cfg.dataset, cfg.transforms)
assert in_channels == [1433,1433,1433]

cfg = hydra.compose(config_name="run.yaml", overrides=["model=graph/gcn", "dataset=simplicial/mantra_orientation"], return_hydra_config=True)
in_channels = infer_in_channels(cfg.dataset, cfg.transforms)
assert in_channels == [1]

cfg = hydra.compose(config_name="run.yaml", overrides=["model=simplicial/scn", "dataset=graph/cocitation_cora"], return_hydra_config=True)
in_channels = infer_in_channels(cfg.dataset, cfg.transforms)
assert in_channels == [1433,1433,1433]



def test_infer_num_cell_dimensions(self):
"""Test infer_num_cell_dimensions."""