diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 6184bdfc..0ccf2e0e 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -3,7 +3,7 @@ from typing import Optional -class TBVariant(Enum): +class TBVariant(int, Enum): """See algo.trajectory_balance.TrajectoryBalance for details.""" TB = 0 @@ -116,8 +116,6 @@ class AlgoConfig: Do not take random actions after this number of steps valid_random_action_prob : float The probability of taking a random action during validation - valid_sample_cond_info : bool - Whether to sample conditioning information during validation (if False, expects a validation set of cond_info) sampling_tau : float The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta) """ @@ -133,7 +131,6 @@ class AlgoConfig: train_random_action_prob: float = 0.0 train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 - valid_sample_cond_info: bool = True sampling_tau: float = 0.0 tb: TBConfig = TBConfig() moql: MOQLConfig = MOQLConfig() diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 6941e7a7..782b4ff4 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, fields, is_dataclass from typing import Optional from omegaconf import MISSING @@ -50,6 +50,8 @@ class Config: Attributes ---------- + desc : str + A description of the experiment log_dir : str The directory where to store logs, checkpoints, and samples. device : str @@ -82,6 +84,7 @@ class Config: Whether to overwrite the contents of the log_dir if it already exists """ + desc: str = "noDesc" log_dir: str = MISSING device: str = "cuda" seed: int = 0 @@ -96,10 +99,28 @@ class Config: hostname: Optional[str] = None pickle_mp_messages: bool = False git_hash: Optional[str] = None - overwrite_existing_exp: bool = True + overwrite_existing_exp: bool = False algo: AlgoConfig = AlgoConfig() model: ModelConfig = ModelConfig() opt: OptimizerConfig = OptimizerConfig() replay: ReplayConfig = ReplayConfig() task: TasksConfig = TasksConfig() cond: ConditionalsConfig = ConditionalsConfig() + + +def init_empty(cfg): + """ + Initialize a dataclass instance with all fields set to MISSING, + including nested dataclasses. + + This is meant to be used on the user side (tasks) to provide + some configuration using the Config class while overwritting + only the fields that have been set by the user. + """ + for f in fields(cfg): + if is_dataclass(f.type): + setattr(cfg, f.name, init_empty(f.type())) + else: + setattr(cfg, f.name, MISSING) + + return cfg diff --git a/src/gflownet/hyperopt/wandb_demo/README.md b/src/gflownet/hyperopt/wandb_demo/README.md new file mode 100644 index 00000000..54132978 --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/README.md @@ -0,0 +1,6 @@ +Everything is contained in one file; `init_wandb_sweep.py` both defines the search space of the sweep and is the entrypoint of wandb agents. + +To launch the search: +1. `python init_wandb_sweep.py` to intialize the sweep +2. `sbatch launch_wandb_agents.sh ` to schedule a jobarray in slurm which will launch wandb agents. +The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being sweeped. diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py new file mode 100644 index 00000000..8927794a --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -0,0 +1,73 @@ +import os +import sys +import time + +import wandb + +from gflownet.config import Config, init_empty +from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer + +TIME = time.strftime("%m-%d-%H-%M") +ENTITY = "valencelabs" +PROJECT = "gflownet" +SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay" +STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}" + + +# Define the search space of the sweep +sweep_config = { + "name": SWEEP_NAME, + "program": "init_wandb_sweep.py", + "controller": { + "type": "cloud", + }, + "method": "grid", + "parameters": { + "config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]}, + "config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]}, + }, +} + + +def wandb_config_merger(): + config = init_empty(Config()) + wandb_config = wandb.config + + # Set desired config values + config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}" + config.print_every = 100 + config.validate_every = 1000 + config.num_final_gen_steps = 1000 + config.num_training_steps = 40_000 + config.pickle_mp_messages = True + config.overwrite_existing_exp = False + config.algo.sampling_tau = 0.95 + config.algo.train_random_action_prob = 0.01 + config.algo.tb.Z_learning_rate = 1e-3 + config.task.seh_moo.objectives = ["seh", "qed"] + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [60.0] + config.cond.weighted_prefs.preference_type = "dirichlet" + config.cond.focus_region.focus_type = None + config.replay.use = False + + # Merge the wandb sweep config with the nested config from gflownet + config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"] + config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"] + + return config + + +if __name__ == "__main__": + # if there no arguments, initialize the sweep, otherwise this is a wandb agent + if len(sys.argv) == 1: + if os.path.exists(STORAGE_DIR): + raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.") + + wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT) + + else: + wandb.init(entity=ENTITY, project=PROJECT) + config = wandb_config_merger() + trial = SEHMOOFragTrainer(config) + trial.run() diff --git a/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh new file mode 100755 index 00000000..c1d27990 --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Purpose: Script to allocate a node and run a wandb sweep agent on it +# Usage: sbatch launch_wandb_agent.sh + +#SBATCH --job-name=wandb_sweep_agent +#SBATCH --array=1-6 +#SBATCH --time=23:59:00 +#SBATCH --output=slurm_output_files/%x_%N_%A_%a.out +#SBATCH --gpus=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=16GB +#SBATCH --partition compute + +source activate gfn-py39-torch113 +echo "Using environment={$CONDA_DEFAULT_ENV}" + +# launch wandb agent +wandb agent --count 1 --entity valencelabs --project gflownet $1 diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index e4955d29..05b00b6e 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -9,7 +9,7 @@ class GraphTransformerConfig: num_mlp_layers: int = 0 -class SeqPosEnc(Enum): +class SeqPosEnc(int, Enum): Pos = 0 Rotary = 1 diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index c815cce9..edda9c79 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -105,13 +105,13 @@ def setup(self): git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] self.cfg.git_hash = git_hash - yaml = OmegaConf.to_yaml(self.cfg) - os.makedirs(self.cfg.log_dir, exist_ok=True) - if self.print_hps: + yaml_cfg = OmegaConf.to_yaml(self.cfg) + if self.print_config: print("\n\nHyperparameters:\n") - print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f: - f.write(yaml) + print(yaml_cfg) + os.makedirs(self.cfg.log_dir, exist_ok=True) + with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: + f.write(yaml_cfg) def step(self, loss: Tensor): loss.backward() diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 7e7df30d..4c29f634 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -26,6 +26,7 @@ class SEHMOOTaskConfig: n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) + log_topk: bool = False online_pareto_front: bool = True diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index c3e8d0f9..9211d038 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -1,4 +1,3 @@ -import os import socket from typing import Dict, List, Tuple, Union @@ -8,7 +7,7 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar @@ -72,17 +71,16 @@ def setup_env_context(self): def main(): - hps = { - "log_dir": "./logs/debug_run_mr4", - "device": "cuda", - "num_training_steps": 10_000, - "num_workers": 8, - "algo": {"tb": {"do_parameterize_p_b": True}}, - } - os.makedirs(hps["log_dir"], exist_ok=True) - - trial = MakeRingsTrainer(hps) - trial.print_every = 1 + """Example of how this model can be run.""" + config = init_empty(Config()) + config.print_every = 1 + config.log_dir = "./logs/debug_run_mr4" + config.device = "cuda" + config.num_training_steps = 10_000 + config.num_workers = 8 + config.algo.tb.do_parameterize_p_b = True + + trial = MakeRingsTrainer(config) trial.run() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9.py similarity index 92% rename from src/gflownet/tasks/qm9/qm9.py rename to src/gflownet/tasks/qm9.py index d66f571a..0e934906 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer @@ -162,3 +162,22 @@ def setup(self): super().setup() self.training_data.setup(self.task, self.ctx) self.test_data.setup(self.task, self.ctx) + + +def main(): + """Example of how this model can be run.""" + config = init_empty(Config()) + config.num_workers = 0 + config.num_training_steps = 100000 + config.validate_every = 100 + config.log_dir = "./logs/debug_qm9" + config.opt.lr_decay = 10000 + config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5" + config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt" + + trial = QM9GapTrainer(config) + trial.run() + + +if __name__ == "__main__": + main() diff --git a/src/gflownet/tasks/qm9/__init__.py b/src/gflownet/tasks/qm9/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/gflownet/tasks/qm9/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py similarity index 99% rename from src/gflownet/tasks/qm9/qm9_moo.py rename to src/gflownet/tasks/qm9_moo.py index cb0e8277..b1dab870 100644 --- a/src/gflownet/tasks/qm9/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -15,7 +15,7 @@ from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer +from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics @@ -197,7 +197,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.sampling_tau = 0.95 # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) # sampling and set the offline ratio to 1 - cfg.algo.valid_sample_cond_info = False + cfg.cond.valid_sample_cond_info = False cfg.algo.valid_offline_ratio = 1 def setup_algo(self): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 91d65818..f95fa15e 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,5 +1,3 @@ -import os -import shutil import socket from typing import Callable, Dict, List, Tuple, Union @@ -13,7 +11,7 @@ from torch.utils.data import Dataset from torch_geometric.data import Data -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer @@ -200,33 +198,21 @@ def setup(self): def main(): - """Example of how this trainer can be run""" - hps = { - "log_dir": "./logs/debug_run_seh_frag_pb", - "device": "cuda" if torch.cuda.is_available() else "cpu", - "overwrite_existing_exp": True, - "num_training_steps": 10_000, - "num_workers": 8, - "opt": { - "lr_decay": 20000, - }, - "algo": {"sampling_tau": 0.99, "offline_ratio": 0.0}, - "cond": { - "temperature": { - "sample_dist": "uniform", - "dist_params": [0, 64.0], - } - }, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) - else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) - - trial = SEHFragTrainer(hps) - trial.print_every = 1 + """Example of how this model can be run.""" + config = init_empty(Config()) + config.print_every = 1 + config.log_dir = "./logs/debug_run_seh_frag_pb" + config.device = "cuda" if torch.cuda.is_available() else "cpu" + config.overwrite_existing_exp = True + config.num_training_steps = 10_000 + config.num_workers = 8 + config.opt.lr_decay = 20_000 + config.algo.sampling_tau = 0.99 + config.algo.offline_ratio = 0.0 + config.cond.temperature.sample_dist = "uniform" + config.cond.temperature.dist_params = [0, 64.0] + + trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 31d8f769..1f8787d3 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,6 +1,4 @@ -import os import pathlib -import shutil from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np @@ -14,7 +12,7 @@ from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask @@ -223,7 +221,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.sampling_tau = 0.95 # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) # sampling and set the offline ratio to 1 - cfg.algo.valid_sample_cond_info = False + cfg.cond.valid_sample_cond_info = False cfg.algo.valid_offline_ratio = 1 def setup_algo(self): @@ -236,6 +234,7 @@ def setup_algo(self): super().setup_algo() def setup_task(self): + self.cfg.cond.moo.num_objectives = len(self.cfg.task.seh_moo.objectives) self.task = SEHMOOTask( dataset=self.training_data, cfg=self.cfg, @@ -324,9 +323,11 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) - self.valid_sampling_hooks.append(self._top_k_hook) + + self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) + if self.cfg.task.seh_moo.log_topk: + self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -334,15 +335,20 @@ def build_callbacks(self): # We use this class-based setup to be compatible with the DeterminedAI API, but no direct # dependency is required. parent = self + callback_dict = {} + + if self.cfg.task.seh_moo.log_topk: - class TopKMetricCB: - def on_validation_end(self, metrics: Dict[str, Any]): - top_k = parent._top_k_hook.finalize() - for i in range(len(top_k)): - metrics[f"topk_rewards_{i}"] = top_k[i] - print("validation end", metrics) + class TopKMetricCB: + def on_validation_end(self, metrics: Dict[str, Any]): + top_k = parent._top_k_hook.finalize() + for i in range(len(top_k)): + metrics[f"topk_rewards_{i}"] = top_k[i] + print("validation end", metrics) - return {"topk": TopKMetricCB()} + callback_dict["topk"] = TopKMetricCB() + + return callback_dict def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: @@ -370,74 +376,29 @@ def __getitem__(self, idx): def main(): """Example of how this model can be run.""" - hps = { - "log_dir": "./logs/debug_run_sfm", - "device": "cuda" if torch.cuda.is_available() else "cpu", - "pickle_mp_messages": True, - "overwrite_existing_exp": True, - "seed": 0, - "num_training_steps": 500, - "num_final_gen_steps": 50, - "validate_every": 100, - "num_workers": 0, - "algo": { - "global_batch_size": 64, - "method": "TB", - "sampling_tau": 0.95, - "train_random_action_prob": 0.01, - "tb": { - "Z_learning_rate": 1e-3, - "Z_lr_decay": 50000, - }, - }, - "model": { - "num_layers": 2, - "num_emb": 256, - }, - "task": { - "seh_moo": { - "objectives": ["seh", "qed"], - "n_valid": 15, - "n_valid_repeats": 128, - }, - }, - "opt": { - "learning_rate": 1e-4, - "lr_decay": 20000, - }, - "cond": { - "temperature": { - "sample_dist": "constant", - "dist_params": [60.0], - "num_thermometer_dim": 32, - }, - "weighted_prefs": { - "preference_type": "dirichlet", - }, - "focus_region": { - "focus_type": None, # "learned-tabular", - "focus_cosim": 0.98, - "focus_limit_coef": 1e-1, - "focus_model_training_limits": (0.25, 0.75), - "focus_model_state_space_res": 30, - "max_train_it": 5_000, - }, - }, - "replay": { - "use": False, - "warmup": 1000, - "hindsight_ratio": 0.0, - }, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) - else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) - - trial = SEHMOOFragTrainer(hps) - trial.print_every = 1 + config = init_empty(Config()) + config.desc = "debug_seh_frag_moo" + config.log_dir = "./logs/debug_run_sfm" + config.device = "cuda" if torch.cuda.is_available() else "cpu" + config.print_every = 1 + config.validate_every = 1 + config.num_final_gen_steps = 5 + config.num_training_steps = 3 + config.pickle_mp_messages = True + config.overwrite_existing_exp = True + config.algo.sampling_tau = 0.95 + config.algo.train_random_action_prob = 0.01 + config.algo.tb.Z_learning_rate = 1e-3 + config.task.seh_moo.objectives = ["seh", "qed"] + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [60.0] + config.cond.weighted_prefs.preference_type = "dirichlet" + config.cond.focus_region.focus_type = None + config.replay.use = False + config.task.seh_moo.n_valid = 15 + config.task.seh_moo.n_valid_repeats = 2 + + trial = SEHMOOFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 7fe0f24b..901baea6 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -1,5 +1,3 @@ -import os -import shutil import socket from typing import Dict, List, Tuple @@ -7,7 +5,7 @@ import torch from torch import Tensor -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv from gflownet.models.seq_transformer import SeqTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer @@ -104,32 +102,21 @@ def setup_algo(self): def main(): - """Example of how this model can be run outside of Determined""" - hps = { - "log_dir": "./logs/debug_run_toy_seq", - "device": "cuda", - "overwrite_existing_exp": True, - "num_training_steps": 2_000, - "checkpoint_every": 200, - "num_workers": 4, - "cond": { - "temperature": { - "sample_dist": "constant", - "dist_params": [2.0], - "num_thermometer_dim": 1, - } - }, - "algo": {"train_random_action_prob": 0.05}, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) - else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) - - trial = ToySeqTrainer(hps) - trial.print_every = 1 + """Example of how this model can be run.""" + config = init_empty(Config()) + config.log_dir = "./logs/debug_run_toy_seq" + config.device = "cuda" + config.overwrite_existing_exp = True + config.num_training_steps = 2_000 + config.checkpoint_every = 200 + config.num_workers = 4 + config.print_every = 1 + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [2.0] + config.cond.temperature.num_thermometer_dim = 1 + config.algo.train_random_action_prob = 0.05 + + trial = ToySeqTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index b7ef6d50..67c03e25 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd +import wandb from omegaconf import OmegaConf from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol @@ -104,17 +105,15 @@ def close(self): class GFNTrainer: - def __init__(self, hps: Dict[str, Any], print_hps=True): + def __init__(self, config: Config, print_config=True): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters ---------- - hps: Dict[str, Any] - A dictionary of hyperparameters. These override default values obtained by the `set_default_hps` method. - device: torch.device - The torch device of the main worker. + config: Config + The hyperparameters for the trainer. """ - self.print_hps = print_hps + self.print_config = print_config self.to_terminate: List[Closable] = [] # self.setup should at least set these up: self.training_data: Dataset @@ -134,11 +133,14 @@ def __init__(self, hps: Dict[str, Any], print_hps=True): # - The default values specified in individual config classes # - The default values specified in the `default_hps` method, typically what is defined by a task # - The values passed in the constructor, typically what is called by the user - # The final config is obtained by merging the three sources - self.cfg: Config = OmegaConf.structured(Config()) - self.set_default_hps(self.cfg) - # OmegaConf returns a fancy object but we can still pretend it's a Config instance - self.cfg = OmegaConf.merge(self.cfg, hps) # type: ignore + # The final config is obtained by merging the three sources with the following precedence: + # config classes < default_hps < constructor (i.e. the constructor overrides the default_hps, and so on) + self.default_cfg: Config = Config() + self.set_default_hps(self.default_cfg) + assert isinstance(self.default_cfg, Config) and isinstance( + config, Config + ) # make sure the config is a Config object, and not the Config class itself + self.cfg = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) # Print the loss every `self.print_every` iterations @@ -173,6 +175,15 @@ def step(self, loss: Tensor): raise NotImplementedError() def setup(self): + if os.path.exists(self.cfg.log_dir): + if self.cfg.overwrite_existing_exp: + shutil.rmtree(self.cfg.log_dir) + else: + raise ValueError( + f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." + ) + os.makedirs(self.cfg.log_dir) + RDLogger.DisableLog("rdApp.*") self.rng = np.random.default_rng(142857) self.env = GraphBuildingEnv() @@ -244,7 +255,7 @@ def build_validation_data_loader(self) -> DataLoader: illegal_action_logreward=self.cfg.algo.illegal_action_logreward, ratio=self.cfg.algo.valid_offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), - sample_cond_info=self.cfg.algo.valid_sample_cond_info, + sample_cond_info=self.cfg.cond.valid_sample_cond_info, stream=False, random_action_prob=self.cfg.algo.valid_random_action_prob, ) @@ -432,6 +443,8 @@ def log(self, info, index, key): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) + if wandb.run is not None: + wandb.log({f"{key}_{k}": v for k, v in info.items()}, step=index) def __del__(self): self.terminate() diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index 54d0660d..8f67af3a 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -29,7 +29,7 @@ class TempCondConfig: @dataclass class MultiObjectiveConfig: - num_objectives: int = 2 + num_objectives: int = 2 # TODO: Change that as it can conflict with cfg.task.seh_moo.num_objectives num_thermometer_dim: int = 16 @@ -61,7 +61,7 @@ class FocusRegionConfig: [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] """ - focus_type: Optional[str] = "learned-tabular" + focus_type: Optional[str] = "centered" use_steer_thermomether: bool = False focus_cosim: float = 0.98 focus_limit_coef: float = 0.1 @@ -72,6 +72,7 @@ class FocusRegionConfig: @dataclass class ConditionalsConfig: + valid_sample_cond_info: bool = True temperature: TempCondConfig = TempCondConfig() moo: MultiObjectiveConfig = MultiObjectiveConfig() weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig()