From 307e58537ab90fb285d3a082af89b8190ee0aadb Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 19:11:40 +0000 Subject: [PATCH 01/22] feat: added support for wandb logging (single runs, not sweeps yet) --- src/gflownet/config.py | 3 +++ src/gflownet/tasks/seh_frag_moo.py | 5 +++++ src/gflownet/trainer.py | 3 +++ 3 files changed, 11 insertions(+) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index be4fa879..3da9e227 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -78,6 +78,8 @@ class Config: The git hash of the current commit overwrite_existing_exp : bool Whether to overwrite the contents of the log_dir if it already exists + use_wandb : bool + Whether to use Weights & Biases for logging """ log_dir: str = MISSING @@ -94,6 +96,7 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = True + use_wandb: bool = False algo: AlgoConfig = AlgoConfig() model: ModelConfig = ModelConfig() opt: OptimizerConfig = OptimizerConfig() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index bd597c31..8a046339 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch_geometric.data as gd +import wandb from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor @@ -380,6 +381,7 @@ def main(): "num_final_gen_steps": 50, "validate_every": 100, "num_workers": 0, + "use_wandb": True, "algo": { "global_batch_size": 64, "method": "TB", @@ -436,6 +438,9 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) + if hps["use_wandb"]: + wandb.init(project="gflownet", config=hps) + trial = SEHMOOFragTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 55e0159b..35c5d395 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd +import wandb from omegaconf import OmegaConf from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol @@ -368,6 +369,8 @@ def log(self, info, index, key): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) + if self.cfg.use_wandb: + wandb.log(info, step=index) def cycle(it): From e2afa44e7778403a9c621e38b9ab06d702104313 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 21:41:37 +0000 Subject: [PATCH 02/22] feat: replaced hps provided as a dict by a Config() object in seh_frag_moo --- src/gflownet/algo/config.py | 1 - src/gflownet/config.py | 22 +++++- src/gflownet/tasks/seh_frag_moo.py | 103 +++++++++-------------------- src/gflownet/trainer.py | 19 +++--- src/gflownet/utils/config.py | 1 + 5 files changed, 61 insertions(+), 85 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index f2bf178a..5433c25e 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -127,7 +127,6 @@ class AlgoConfig: valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 valid_random_action_prob: float = 0.0 - valid_sample_cond_info: bool = True sampling_tau: float = 0.0 tb: TBConfig = TBConfig() moql: MOQLConfig = MOQLConfig() diff --git a/src/gflownet/config.py 
b/src/gflownet/config.py index 3da9e227..616384ac 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, fields, is_dataclass from typing import Optional from omegaconf import MISSING @@ -95,7 +95,7 @@ class Config: hostname: Optional[str] = None pickle_mp_messages: bool = False git_hash: Optional[str] = None - overwrite_existing_exp: bool = True + overwrite_existing_exp: bool = False use_wandb: bool = False algo: AlgoConfig = AlgoConfig() model: ModelConfig = ModelConfig() @@ -103,3 +103,21 @@ class Config: replay: ReplayConfig = ReplayConfig() task: TasksConfig = TasksConfig() cond: ConditionalsConfig = ConditionalsConfig() + + +def init_missing(cfg: Config) -> Config: + """ + Initialize a dataclass instance with all fields set to MISSING, + including nested dataclasses. + + This is meant to be used on the user side (tasks) to provide + some configuration using the Config class while overwriting + only the fields that have been set by the user. + """ + for f in fields(cfg): + if is_dataclass(f.type): + setattr(cfg, f.name, init_missing(f.type())) + else: + setattr(cfg, f.name, MISSING) + + return cfg diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 8a046339..0c3e3f26 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -7,15 +7,15 @@ import torch import torch.nn as nn import torch_geometric.data as gd -import wandb from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset +import wandb from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -from gflownet.config import Config +from gflownet.config import Config, init_missing from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask @@ -217,7 +217,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.sampling_tau = 0.95 # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) # sampling and set the offline ratio to 1 - cfg.algo.valid_sample_cond_info = False + cfg.cond.valid_sample_cond_info = False cfg.algo.valid_offline_ratio = 1 def setup_algo(self): @@ -371,78 +371,37 @@ def __getitem__(self, idx): def main(): """Example of how this model can be run.""" - hps = { - "log_dir": "./logs/debug_run_sfm", - "device": "cuda" if torch.cuda.is_available() else "cpu", - "pickle_mp_messages": True, - "overwrite_existing_exp": True, - "seed": 0, - "num_training_steps": 500, - "num_final_gen_steps": 50, - "validate_every": 100, - "num_workers": 0, - "use_wandb": True, - "algo": { - "global_batch_size": 64, - "method": "TB", - "sampling_tau": 0.95, - "train_random_action_prob": 0.01, - "tb": { - "Z_learning_rate": 1e-3, - "Z_lr_decay": 50000, - }, - }, - "model": { - "num_layers": 2, - "num_emb": 256, - }, - "task": { - "seh_moo": { - "objectives": ["seh", "qed"], - "n_valid": 15, - "n_valid_repeats": 128, - }, - }, - "opt": { - "learning_rate": 1e-4, - "lr_decay": 20000, - }, - "cond": { - "temperature": { - "sample_dist": "constant", - "dist_params": [60.0], - "num_thermometer_dim": 32, - }, - "weighted_prefs": { - "preference_type": "dirichlet", - }, - "focus_region": { - "focus_type": None, # "learned-tabular",
"focus_cosim": 0.98, - "focus_limit_coef": 1e-1, - "focus_model_training_limits": (0.25, 0.75), - "focus_model_state_space_res": 30, - "max_train_it": 5_000, - }, - }, - "replay": { - "use": False, - "warmup": 1000, - "hindsight_ratio": 0.0, - }, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) + config = init_missing(Config()) + config.log_dir = "./logs/debug_run_sfm" + config.device = "cuda" if torch.cuda.is_available() else "cpu" + config.print_every = 1 + config.validate_every = 1 + config.num_final_gen_steps = 5 + config.num_training_steps = 3 + config.pickle_mp_messages = True + config.overwrite_existing_exp = True + config.use_wandb = False + config.algo.sampling_tau = 0.95 + config.algo.train_random_action_prob = 0.01 + config.algo.tb.Z_learning_rate = 1e-3 + config.task.seh_moo.objectives = ["seh", "qed"] + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [60.0] + config.cond.weighted_prefs.preference_type = "dirichlet" + config.cond.focus_region.focus_type = None + config.replay.use = False + + if os.path.exists(config.log_dir): + if config.overwrite_existing_exp: + shutil.rmtree(config.log_dir) else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) + raise ValueError(f"Log dir {config.log_dir} already exists. Set overwrite_existing_exp=True to delete it.") + os.makedirs(config.log_dir) - if hps["use_wandb"]: - wandb.init(project="gflownet", config=hps) + if config.use_wandb: + wandb.init(project="gflownet", config=config) - trial = SEHMOOFragTrainer(hps) - trial.print_every = 1 + trial = SEHMOOFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 35c5d395..5e7d7776 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -92,15 +92,13 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class GFNTrainer: - def __init__(self, hps: Dict[str, Any]): + def __init__(self, config: Config): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters ---------- - hps: Dict[str, Any] - A dictionary of hyperparameters. These override default values obtained by the `set_default_hps` method. - device: torch.device - The torch device of the main worker. + config: Config + The hyperparameters for the trainer. """ # self.setup should at least set these up: self.training_data: Dataset @@ -120,11 +118,12 @@ def __init__(self, hps: Dict[str, Any]): # - The default values specified in individual config classes # - The default values specified in the `default_hps` method, typically what is defined by a task # - The values passed in the constructor, typically what is called by the user - # The final config is obtained by merging the three sources - self.cfg: Config = OmegaConf.structured(Config()) - self.set_default_hps(self.cfg) + # The final config is obtained by merging the three sources with the following precedence: + # config classes < default_hps < constructor (i.e. 
the constructor overrides the default_hps, and so on) + self.default_cfg: Config = OmegaConf.structured(Config()) + self.set_default_hps(self.default_cfg) # OmegaConf returns a fancy object but we can still pretend it's a Config instance - self.cfg = OmegaConf.merge(self.cfg, hps) # type: ignore + self.cfg = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) # Print the loss every `self.print_every` iterations @@ -230,7 +229,7 @@ def build_validation_data_loader(self) -> DataLoader: illegal_action_logreward=self.cfg.algo.illegal_action_logreward, ratio=self.cfg.algo.valid_offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), - sample_cond_info=self.cfg.algo.valid_sample_cond_info, + sample_cond_info=self.cfg.cond.valid_sample_cond_info, stream=False, random_action_prob=self.cfg.algo.valid_random_action_prob, ) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index db3d3905..5c9fa7bc 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -71,6 +71,7 @@ class FocusRegionConfig: @dataclass class ConditionalsConfig: + valid_sample_cond_info: bool = True temperature: TempCondConfig = TempCondConfig() moo: MultiObjectiveConfig = MultiObjectiveConfig() weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() From 22e0a6cbdfe629d130d5e09708ca8faadab47848 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 23:10:29 +0000 Subject: [PATCH 03/22] feat: added config.desc --- src/gflownet/config.py | 3 +++ src/gflownet/tasks/seh_frag_moo.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 616384ac..de570e08 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -50,6 +50,8 @@ class Config: Attributes ---------- + desc : str + A description of the experiment log_dir : str The directory where to store logs, checkpoints, and samples. 
device : str @@ -82,6 +84,7 @@ class Config: Whether to use Weights & Biases for logging """ + desc: str = "noDesc" log_dir: str = MISSING device: str = "cuda" seed: int = 0 diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 0c3e3f26..39068b7a 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -372,6 +372,7 @@ def __getitem__(self, idx): def main(): """Example of how this model can be run.""" config = init_missing(Config()) + config.desc = "debug_seh_frag_moo" config.log_dir = "./logs/debug_run_sfm" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.print_every = 1 @@ -380,7 +381,7 @@ def main(): config.num_training_steps = 3 config.pickle_mp_messages = True config.overwrite_existing_exp = True - config.use_wandb = False + config.use_wandb = True config.algo.sampling_tau = 0.95 config.algo.train_random_action_prob = 0.01 config.algo.tb.Z_learning_rate = 1e-3 @@ -399,7 +400,7 @@ def main(): os.makedirs(config.log_dir) if config.use_wandb: - wandb.init(project="gflownet", config=config) + wandb.init(project="gflownet", config=config, name=config.desc) trial = SEHMOOFragTrainer(config) trial.run() From 4a90c1a54577fd08da9ad1a57ec5df98af4190b2 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 23:14:21 +0000 Subject: [PATCH 04/22] fix: added train/valid for wandb log --- src/gflownet/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 5e7d7776..1484e23d 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -369,7 +369,7 @@ def log(self, info, index, key): for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) if self.cfg.use_wandb: - wandb.log(info, step=index) + wandb.log({f"{key}_{k}": v for k, v in info.items()}, step=index) def cycle(it): From 187c9709ab913aa88f62e22ae235a3d0713fddfe Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 23:28:17 +0000 Subject: [PATCH 05/22] fix: allow JSON serialization of Enum objects --- src/gflownet/algo/config.py | 2 +- src/gflownet/models/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 5433c25e..e3c45721 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -3,7 +3,7 @@ from typing import Optional -class TBVariant(Enum): +class TBVariant(int, Enum): """See algo.trajectory_balance.TrajectoryBalance for details.""" TB = 0 diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index e4955d29..05b00b6e 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -9,7 +9,7 @@ class GraphTransformerConfig: num_mlp_layers: int = 0 -class SeqPosEnc(Enum): +class SeqPosEnc(int, Enum): Pos = 0 Rotary = 1 From 235f058c0f7497df6348140ad135a427439ea284 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 13 Feb 2024 23:35:48 +0000 Subject: [PATCH 06/22] chore: tox --- src/gflownet/config.py | 8 ++++---- src/gflownet/tasks/seh_frag_moo.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index de570e08..55fd7467 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -108,19 +108,19 @@ class Config: cond: ConditionalsConfig = ConditionalsConfig() -def init_missing(cfg: Config) -> Config: +def init_empty(cfg): """ Initialize a dataclass instance with all fields set to MISSING, 
including nested dataclasses. - + This is meant to be used on the user side (tasks) to provide some configuration using the Config class while overwriting only the fields that have been set by the user. """ for f in fields(cfg): if is_dataclass(f.type): - setattr(cfg, f.name, init_missing(f.type())) + setattr(cfg, f.name, init_empty(f.type())) else: setattr(cfg, f.name, MISSING) - + return cfg diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 39068b7a..f7191166 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -7,15 +7,15 @@ import torch import torch.nn as nn import torch_geometric.data as gd +import wandb from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset -import wandb from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -from gflownet.config import Config, init_missing +from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask @@ -371,7 +371,7 @@ def __getitem__(self, idx): def main(): """Example of how this model can be run.""" - config = init_missing(Config()) + config = init_empty(Config()) config.desc = "debug_seh_frag_moo" config.log_dir = "./logs/debug_run_sfm" config.device = "cuda" if torch.cuda.is_available() else "cpu" From 98d2522502bb57aa78700878555d36d9a2e96f50 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 15:02:40 +0000 Subject: [PATCH 07/22] chore: replaced hps (dict) by Config() in all tasks.
Moved qm9.py out of qm9/ --- src/gflownet/tasks/make_rings.py | 31 +++++++++++-------- src/gflownet/tasks/{qm9 => }/qm9.py | 26 +++++++++++----- src/gflownet/tasks/qm9/__init__.py | 0 src/gflownet/tasks/qm9/qm9.yaml | 10 ------ src/gflownet/tasks/seh_frag.py | 47 +++++++++++++---------------- src/gflownet/tasks/toy_seq.py | 46 +++++++++++++--------------- 6 files changed, 80 insertions(+), 80 deletions(-) rename src/gflownet/tasks/{qm9 => }/qm9.py (89%) delete mode 100644 src/gflownet/tasks/qm9/__init__.py delete mode 100644 src/gflownet/tasks/qm9/qm9.yaml diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index c3e8d0f9..496f292e 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -1,4 +1,5 @@ import os +import shutil import socket from typing import Dict, List, Tuple, Union @@ -8,7 +9,7 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar @@ -72,17 +73,23 @@ def setup_env_context(self): def main(): - hps = { - "log_dir": "./logs/debug_run_mr4", - "device": "cuda", - "num_training_steps": 10_000, - "num_workers": 8, - "algo": {"tb": {"do_parameterize_p_b": True}}, - } - os.makedirs(hps["log_dir"], exist_ok=True) - - trial = MakeRingsTrainer(hps) - trial.print_every = 1 + """Example of how this model can be run.""" + config = init_empty(Config()) + config.print_every = 1 + config.log_dir = "./logs/debug_run_mr4" + config.device = "cuda" + config.num_training_steps = 10_000 + config.num_workers = 8 + config.algo.tb.do_parameterize_p_b = True + + if os.path.exists(config.log_dir): + if config.overwrite_existing_exp: + shutil.rmtree(config.log_dir) + else: + raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") + os.makedirs(config.log_dir) + + trial = MakeRingsTrainer(config) trial.run() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9.py similarity index 89% rename from src/gflownet/tasks/qm9/qm9.py rename to src/gflownet/tasks/qm9.py index 866a7fac..212602e6 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -1,4 +1,5 @@ import os +import shutil from typing import Callable, Dict, List, Tuple, Union import numpy as np @@ -6,12 +7,11 @@ import torch.nn as nn import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol -from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer @@ -145,11 +145,23 @@ def setup_task(self): def main(): """Example of how this model can be run.""" - yaml = YAML(typ="safe", pure=True) - config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") - with open(config_file, "r") as f: - hps = yaml.load(f) - trial = QM9GapTrainer(hps) + config = init_empty(Config()) + config.num_workers = 0 + config.num_training_steps = 100000 + config.validate_every = 100 + config.log_dir = "./logs/debug_qm9" + config.opt.lr_decay = 10000 + config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5" + config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt" + + if os.path.exists(config.log_dir): + if config.overwrite_existing_exp: + shutil.rmtree(config.log_dir) + else: + raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") + os.makedirs(config.log_dir) + + trial = QM9GapTrainer(config) trial.run() diff --git a/src/gflownet/tasks/qm9/__init__.py b/src/gflownet/tasks/qm9/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml deleted file mode 100644 index 19701fac..00000000 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ /dev/null @@ -1,10 +0,0 @@ -opt: - lr_decay: 10000 -task: - qm9: - h5_path: /rxrx/data/chem/qm9/qm9.h5 - model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt -num_training_steps: 100000 -validate_every: 100 -log_dir: ./logs/debug_qm9 -num_workers: 0 diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e916f732..c163ccf7 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -12,7 +12,7 @@ from torch import Tensor from torch.utils.data import Dataset -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer @@ -185,33 +185,28 @@ def setup(self): def main(): - """Example of how this trainer can be run""" - hps = { - "log_dir": "./logs/debug_run_seh_frag_pb", - "device": "cuda" if torch.cuda.is_available() else "cpu", - "overwrite_existing_exp": True, - "num_training_steps": 10_000, - "num_workers": 8, - "opt": { - "lr_decay": 20000, - }, - "algo": {"sampling_tau": 0.99, "offline_ratio": 0.0}, - "cond": { - "temperature": { - "sample_dist": "uniform", - "dist_params": [0, 64.0], - } - }, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) + """Example of how this model can be run.""" + config = init_empty(Config()) + config.print_every = 1 + config.log_dir = "./logs/debug_run_seh_frag_pb" + config.device = "cuda" if torch.cuda.is_available() else "cpu" + config.overwrite_existing_exp = True + config.num_training_steps = 10_000 + config.num_workers = 8 + config.opt.lr_decay = 20_000 + config.algo.sampling_tau = 0.99 + config.algo.offline_ratio = 0.0 + config.cond.temperature.sample_dist = "uniform" + config.cond.temperature.dist_params = [0, 64.0] + + if os.path.exists(config.log_dir): + if config.overwrite_existing_exp: + shutil.rmtree(config.log_dir) else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) + raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") + os.makedirs(config.log_dir) - trial = SEHFragTrainer(hps) - trial.print_every = 1 + trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 7fe0f24b..dff78828 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -7,7 +7,7 @@ import torch from torch import Tensor -from gflownet.config import Config +from gflownet.config import Config, init_empty from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv from gflownet.models.seq_transformer import SeqTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer @@ -104,32 +104,28 @@ def setup_algo(self): def main(): - """Example of how this model can be run outside of Determined""" - hps = { - "log_dir": "./logs/debug_run_toy_seq", - "device": "cuda", - "overwrite_existing_exp": True, - "num_training_steps": 2_000, - "checkpoint_every": 200, - "num_workers": 4, - "cond": { - "temperature": { - "sample_dist": "constant", - "dist_params": [2.0], - "num_thermometer_dim": 1, - } - }, - "algo": {"train_random_action_prob": 0.05}, - } - if os.path.exists(hps["log_dir"]): - if hps["overwrite_existing_exp"]: - shutil.rmtree(hps["log_dir"]) + """Example of how this model can be run.""" + config = init_empty(Config()) + config.log_dir = "./logs/debug_run_toy_seq" + config.device = "cuda" + config.overwrite_existing_exp = True + config.num_training_steps = 2_000 + config.checkpoint_every = 200 + config.num_workers = 4 + config.print_every = 1 + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [2.0] + config.cond.temperature.num_thermometer_dim = 1 + config.algo.train_random_action_prob = 0.05 + + if os.path.exists(config.log_dir): + if config.overwrite_existing_exp: + shutil.rmtree(config.log_dir) else: - raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(hps["log_dir"]) + raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") + os.makedirs(config.log_dir) - trial = ToySeqTrainer(hps) - trial.print_every = 1 + trial = ToySeqTrainer(config) trial.run() From 5c5a751d7b2cf4d3958fc0414490f9079352a907 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 19:24:24 +0000 Subject: [PATCH 08/22] fix: changed default focus_region for frag_moo --- src/gflownet/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index 5c9fa7bc..dd9c57e7 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -60,7 +60,7 @@ class FocusRegionConfig: [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] """ - focus_type: Optional[str] = "learned-tabular" + focus_type: Optional[str] = "centered" use_steer_thermomether: bool = False focus_cosim: float = 0.98 focus_limit_coef: float = 0.1 From 0a244f26277fcfb5fb37f716c4027d530f89a573 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 19:24:58 +0000 Subject: [PATCH 09/22] fix: added assert to prevent inadvertently manipulating a Config rather than Config() object --- src/gflownet/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 1484e23d..6a2133bb 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -120,9 +120,9 @@ def __init__(self, config: Config): # - The values passed in the constructor, typically what is called by the user # The final config is obtained by merging the three sources with the following precedence: # config classes < default_hps < constructor (i.e. the constructor overrides the default_hps, and so on) - self.default_cfg: Config = OmegaConf.structured(Config()) + self.default_cfg: Config = Config() self.set_default_hps(self.default_cfg) - # OmegaConf returns a fancy object but we can still pretend it's a Config instance + assert isinstance(self.default_cfg, Config) and isinstance(config, Config) # make sure the config is a Config object, and not the Config class itself self.cfg = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) From 472c16a2fc6e4918134f66176b75a7094ef15f8e Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 19:50:08 +0000 Subject: [PATCH 10/22] removed cfg.use_wandb and simply test if wandb has been initialised in trainer --- src/gflownet/config.py | 3 --- src/gflownet/tasks/seh_frag_moo.py | 5 +---- src/gflownet/trainer.py | 2 +- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 55fd7467..52a10ee9 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -80,8 +80,6 @@ class Config: The git hash of the current commit overwrite_existing_exp : bool Whether to overwrite the contents of the log_dir if it already exists - use_wandb : bool - Whether to use Weights & Biases for logging """ desc: str = "noDesc" @@ -99,7 +97,6 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False - use_wandb: bool = False algo: AlgoConfig = AlgoConfig() model: ModelConfig = ModelConfig() opt: OptimizerConfig = OptimizerConfig() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index f7191166..d5806b27 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -230,6 +230,7 @@ def setup_algo(self): 
super().setup_algo() def setup_task(self): + self.cfg.cond.moo.num_objectives = len(self.cfg.task.seh_moo.objectives) self.task = SEHMOOTask( dataset=self.training_data, cfg=self.cfg, @@ -381,7 +382,6 @@ def main(): config.num_training_steps = 3 config.pickle_mp_messages = True config.overwrite_existing_exp = True - config.use_wandb = True config.algo.sampling_tau = 0.95 config.algo.train_random_action_prob = 0.01 config.algo.tb.Z_learning_rate = 1e-3 @@ -399,9 +399,6 @@ def main(): raise ValueError(f"Log dir {config.log_dir} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(config.log_dir) - if config.use_wandb: - wandb.init(project="gflownet", config=config, name=config.desc) - trial = SEHMOOFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 6a2133bb..d894e7c9 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -368,7 +368,7 @@ def log(self, info, index, key): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) - if self.cfg.use_wandb: + if wandb.run is None: wandb.log({f"{key}_{k}": v for k, v in info.items()}, step=index) From 046edaf665fe19723d31ee207d9e09ec7094b2af Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 20:02:07 +0000 Subject: [PATCH 11/22] chore: adding comment --- src/gflownet/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index dd9c57e7..7e0ed3cf 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -29,7 +29,7 @@ class TempCondConfig: @dataclass class MultiObjectiveConfig: - num_objectives: int = 2 + num_objectives: int = 2 # TODO: Change that as it can conflict with cfg.task.seh_moo.num_objectives num_thermometer_dim: int = 16 From d64e19f1725669dcc76aec6ee2aa5ef894c56dd4 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 22:02:26 +0000 Subject: [PATCH 12/22] fix: typo --- src/gflownet/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d894e7c9..48ce193a 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -368,7 +368,7 @@ def log(self, info, index, key): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) - if wandb.run is None: + if wandb.run is not None: wandb.log({f"{key}_{k}": v for k, v in info.items()}, step=index) From d28d8603444e6d5b4626225543cbd2867e77ce6d Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 22:08:11 +0000 Subject: [PATCH 13/22] chore: added cfg.task.seh_moo.log_topk to de-clutter a bit --- src/gflownet/tasks/config.py | 1 + src/gflownet/tasks/seh_frag_moo.py | 26 ++++++++++++++++---------- src/gflownet/trainer.py | 4 +++- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 28960399..c9c04da8 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -24,6 +24,7 @@ class SEHMOOTaskConfig: n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) + log_topk: bool = False @dataclass diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index d5806b27..e72fea38 100644 --- 
a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn import torch_geometric.data as gd -import wandb from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor @@ -320,9 +319,11 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) - self.valid_sampling_hooks.append(self._top_k_hook) + + self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) + if self.cfg.task.seh_moo.log_topk: + self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -330,15 +331,20 @@ def build_callbacks(self): # We use this class-based setup to be compatible with the DeterminedAI API, but no direct # dependency is required. parent = self + callback_dict = {} + + if self.cfg.task.seh_moo.log_topk: + + class TopKMetricCB: + def on_validation_end(self, metrics: Dict[str, Any]): + top_k = parent._top_k_hook.finalize() + for i in range(len(top_k)): + metrics[f"topk_rewards_{i}"] = top_k[i] + print("validation end", metrics) - class TopKMetricCB: - def on_validation_end(self, metrics: Dict[str, Any]): - top_k = parent._top_k_hook.finalize() - for i in range(len(top_k)): - metrics[f"topk_rewards_{i}"] = top_k[i] - print("validation end", metrics) + callback_dict["topk"] = TopKMetricCB() - return {"topk": TopKMetricCB()} + return callback_dict def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 48ce193a..fb6e5959 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -122,7 +122,9 @@ def __init__(self, config: Config): # config classes < default_hps < constructor (i.e. 
the constructor overrides the default_hps, and so on) self.default_cfg: Config = Config() self.set_default_hps(self.default_cfg) - assert isinstance(self.default_cfg, Config) and isinstance(config, Config) # make sure the config is a Config object, and not the Config class itself + assert isinstance(self.default_cfg, Config) and isinstance( + config, Config + ) # make sure the config is a Config object, and not the Config class itself self.cfg = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) From d5d200ea47b560a404c3421f009937e7e278ee77 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 22:45:57 +0000 Subject: [PATCH 14/22] fix: added wandb to dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d588c636..f080c529 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dependencies = [ "pyro-ppl", "gpytorch", "omegaconf>=2.3", + "wandb", ] [project.optional-dependencies] From 39e71f86749b1b8c8eb79cb6447b04a189c2bf4a Mon Sep 17 00:00:00 2001 From: julienroyd Date: Wed, 14 Feb 2024 23:04:31 +0000 Subject: [PATCH 15/22] minor: file name change for consistency --- src/gflownet/online_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 98791be5..22510a2c 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -89,7 +89,7 @@ def setup(self): print("\n\nHyperparameters:\n") yaml = OmegaConf.to_yaml(self.cfg) print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w") as f: f.write(yaml) def step(self, loss: Tensor): From fa1d7b3456306273111e694591293a410893ce4a Mon Sep 17 00:00:00 2001 From: julienroyd Date: Fri, 16 Feb 2024 20:14:11 +0000 Subject: [PATCH 16/22] chore: centralised self.cfg.overwrite_existing_exp in GFNTrainer() (removed from all tasks to simplify mains) --- src/gflownet/tasks/make_rings.py | 9 --------- src/gflownet/tasks/qm9.py | 9 --------- src/gflownet/tasks/seh_frag.py | 9 --------- src/gflownet/tasks/seh_frag_moo.py | 9 --------- src/gflownet/tasks/toy_seq.py | 9 --------- src/gflownet/trainer.py | 10 ++++++++++ 6 files changed, 10 insertions(+), 45 deletions(-) diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 496f292e..9211d038 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -1,5 +1,3 @@ -import os -import shutil import socket from typing import Dict, List, Tuple, Union @@ -82,13 +80,6 @@ def main(): config.num_workers = 8 config.algo.tb.do_parameterize_p_b = True - if os.path.exists(config.log_dir): - if config.overwrite_existing_exp: - shutil.rmtree(config.log_dir) - else: - raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") - os.makedirs(config.log_dir) - trial = MakeRingsTrainer(config) trial.run() diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 212602e6..79fcc68e 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -1,5 +1,3 @@ -import os -import shutil from typing import Callable, Dict, List, Tuple, Union import numpy as np @@ -154,13 +152,6 @@ def main(): config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5" config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt" - if os.path.exists(config.log_dir): - if config.overwrite_existing_exp: - shutil.rmtree(config.log_dir) - else: - raise ValueError(f"Log dir {config.log_dir} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(config.log_dir) - trial = QM9GapTrainer(config) trial.run() diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index c163ccf7..268ed948 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,5 +1,3 @@ -import os -import shutil import socket from typing import Callable, Dict, List, Tuple, Union @@ -199,13 +197,6 @@ def main(): config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - if os.path.exists(config.log_dir): - if config.overwrite_existing_exp: - shutil.rmtree(config.log_dir) - else: - raise ValueError(f"Log dir {config.log_dir} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(config.log_dir) - trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index e72fea38..c792cd15 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,6 +1,4 @@ -import os import pathlib -import shutil from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np @@ -398,13 +396,6 @@ def main(): config.cond.focus_region.focus_type = None config.replay.use = False - if os.path.exists(config.log_dir): - if config.overwrite_existing_exp: - shutil.rmtree(config.log_dir) - else: - raise ValueError(f"Log dir {config.log_dir} already exists. Set overwrite_existing_exp=True to delete it.") - os.makedirs(config.log_dir) - trial = SEHMOOFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index dff78828..901baea6 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -1,5 +1,3 @@ -import os -import shutil import socket from typing import Dict, List, Tuple @@ -118,13 +116,6 @@ def main(): config.cond.temperature.num_thermometer_dim = 1 config.algo.train_random_action_prob = 0.05 - if os.path.exists(config.log_dir): - if config.overwrite_existing_exp: - shutil.rmtree(config.log_dir) - else: - raise ValueError(f"Log dir {config.log_dir} already exists. 
Set overwrite_existing_exp=True to delete it.") - os.makedirs(config.log_dir) - trial = ToySeqTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index fb6e5959..f195bfeb 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -1,5 +1,6 @@ import os import pathlib +import shutil from typing import Any, Callable, Dict, List, NewType, Optional, Tuple import numpy as np @@ -160,6 +161,15 @@ def step(self, loss: Tensor): raise NotImplementedError() def setup(self): + if os.path.exists(self.cfg.log_dir): + if self.cfg.overwrite_existing_exp: + shutil.rmtree(self.cfg.log_dir) + else: + raise ValueError( + f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." + ) + os.makedirs(self.cfg.log_dir) + RDLogger.DisableLog("rdApp.*") self.rng = np.random.default_rng(142857) self.env = GraphBuildingEnv() From c530a003e69cdc269e61a9974a646ce4f2dc4434 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Fri, 16 Feb 2024 20:19:59 +0000 Subject: [PATCH 17/22] feat: added hyperopt/wandb_demo --- src/gflownet/hyperopt/wandb_demo/README.md | 7 +++ .../hyperopt/wandb_demo/init_wandb_sweep.py | 52 +++++++++++++++++++ .../wandb_demo/launch_wandb_agents.sh | 19 +++++++ .../hyperopt/wandb_demo/wandb_agent_main.py | 10 ++++ 4 files changed, 88 insertions(+) create mode 100644 src/gflownet/hyperopt/wandb_demo/README.md create mode 100644 src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py create mode 100755 src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh create mode 100644 src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py diff --git a/src/gflownet/hyperopt/wandb_demo/README.md b/src/gflownet/hyperopt/wandb_demo/README.md new file mode 100644 index 00000000..a2013e43 --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/README.md @@ -0,0 +1,7 @@ +These are the two files used to execute wandb searches: +1. `init_wandb_sweep.py` defines the base configuration and the hyperparameters to sweep over. +2. `wandb_agent_main.py` is executed by wandb agents that are managed by the wandb sweep. + +To launch the search +1. `python init_wandb_sweep.py` to initialize the sweep +2. `sbatch launch_wandb_agents.sh <sweep_id>` to schedule a job array in slurm which will launch wandb agents with `wandb_agent_main.py` as the entrypoint. The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being swept.
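A note on the demo below: the job-array size in `launch_wandb_agents.sh` has to match the number of grid points the sweep enumerates. For the grid defined in `init_wandb_sweep.py`, 5 values of `Z_learning_rate` times 4 values of `Z_lr_decay` gives 20 runs, hence `#SBATCH --array=1-20`. A minimal sketch of that bookkeeping (the `grid_size` helper is illustrative, not part of this patch):

import math

# Illustrative helper (not part of this patch): how many runs a wandb
# grid sweep enumerates, given its "parameters" section.
def grid_size(parameters: dict) -> int:
    return math.prod(len(spec["values"]) for spec in parameters.values())

parameters = {
    "config.algo.tb.Z_learning_rate": {"values": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]},
    "config.algo.tb.Z_lr_decay": {"values": [2_000, 10_000, 50_000, 250_000]},
}
assert grid_size(parameters) == 20  # matches #SBATCH --array=1-20 below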
diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py new file mode 100644 index 00000000..6921ede2 --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -0,0 +1,52 @@ +from pathlib import Path + +import wandb + +import gflownet +from gflownet.config import Config, init_empty + +sweep_config = { + "name": "sehFragMoo-Zlr-Zlrdecay", + "program": "wandb_agent_main.py", + "controller": { + "type": "cloud", + }, + "method": "grid", + "parameters": { + "config.algo.tb.Z_learning_rate": {"values": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]}, + "config.algo.tb.Z_lr_decay": {"values": [2_000, 10_000, 50_000, 250_000]}, + }, +} + + +def wandb_config_merger(): + config = init_empty(Config()) + wandb_config = wandb.config + + # Set desired config values + config.log_dir = str(Path(gflownet.__file__).parent / "sweeps" / sweep_config["name"] / "run_logs" / wandb.run.name) + config.print_every = 100 + config.validate_every = 1000 + config.num_final_gen_steps = 1000 + config.num_training_steps = 40_000 + config.pickle_mp_messages = True + config.overwrite_existing_exp = False + config.algo.sampling_tau = 0.95 + config.algo.train_random_action_prob = 0.01 + config.algo.tb.Z_learning_rate = 1e-3 + config.task.seh_moo.objectives = ["seh", "qed"] + config.cond.temperature.sample_dist = "constant" + config.cond.temperature.dist_params = [60.0] + config.cond.weighted_prefs.preference_type = "dirichlet" + config.cond.focus_region.focus_type = None + config.replay.use = False + + # Merge the wandb sweep config with the nested config from gflownet + config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"] + config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"] + + return config + + +if __name__ == "__main__": + wandb.sweep(sweep_config, entity="valencelabs", project="gflownet") diff --git a/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh new file mode 100755 index 00000000..03526f57 --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Purpose: Script to allocate a node and run a wandb sweep agent on it +# Usage: sbatch launch_wandb_agent.sh + +#SBATCH --job-name=wandb_sweep_agent +#SBATCH --array=1-20 +#SBATCH --time=23:59:00 +#SBATCH --output=slurm_output_files/%x_%N_%A_%a.out +#SBATCH --gpus=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=16GB +#SBATCH --partition compute + +source activate gfn-py39-torch113 +echo "Using environment={$CONDA_DEFAULT_ENV}" + +# launch wandb agent +wandb agent --count 1 --entity valencelabs --project gflownet $1 diff --git a/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py b/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py new file mode 100644 index 00000000..4a294c1f --- /dev/null +++ b/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py @@ -0,0 +1,10 @@ +import wandb +from init_wandb_sweep import wandb_config_merger + +import gflownet.tasks.seh_frag_moo as seh_frag_moo + +if __name__ == "__main__": + wandb.init(entity="valencelabs", project="gflownet") + config = wandb_config_merger() + trial = seh_frag_moo.SEHMOOFragTrainer(config) + trial.run() From 90228db25058e637d40ad0f2afdbbb559cae9a1d Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 27 Feb 2024 15:37:27 +0000 Subject: [PATCH 18/22] feat: removed wandb_agent_main.py to have the search and entrypoint defined in a single file --- src/gflownet/hyperopt/wandb_demo/README.md | 9 
+++-- .../hyperopt/wandb_demo/init_wandb_sweep.py | 33 +++++++++++++++---- .../hyperopt/wandb_demo/wandb_agent_main.py | 10 ------ 3 files changed, 31 insertions(+), 21 deletions(-) delete mode 100644 src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py diff --git a/src/gflownet/hyperopt/wandb_demo/README.md b/src/gflownet/hyperopt/wandb_demo/README.md index a2013e43..54132978 100644 --- a/src/gflownet/hyperopt/wandb_demo/README.md +++ b/src/gflownet/hyperopt/wandb_demo/README.md @@ -1,7 +1,6 @@ -These are the two files used to execute wandb searches: -1. `init_wandb_sweep.py` defines the base configuration and the hyperparameters to sweep over. -2. `wandb_agent_main.py` is executed by wandb agents that are managed by the wandb sweep. +Everything is contained in one file; `init_wandb_sweep.py` both defines the search space of the sweep and is the entrypoint of wandb agents. -To launch the search +To launch the search: 1. `python init_wandb_sweep.py` to initialize the sweep -2. `sbatch launch_wandb_agents.sh <sweep_id>` to schedule a job array in slurm which will launch wandb agents with `wandb_agent_main.py` as the entrypoint. The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being swept. +2. `sbatch launch_wandb_agents.sh <sweep_id>` to schedule a job array in slurm which will launch wandb agents. +The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being swept. diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py index 6921ede2..59b0211b 100644 --- a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -1,13 +1,22 @@ -from pathlib import Path +import os +import time import wandb -import gflownet +import gflownet.tasks.seh_frag_moo as seh_frag_moo from gflownet.config import Config, init_empty +TIME = time.strftime("%m-%d-%H-%M") +ENTITY = "valencelabs" +PROJECT = "syngfn" +SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay" +STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}" + + +# Define the search space of the sweep sweep_config = { - "name": "sehFragMoo-Zlr-Zlrdecay", - "program": "wandb_agent_main.py", + "name": SWEEP_NAME, + "program": "init_wandb_sweep.py", "controller": { "type": "cloud", }, "method": "grid", "parameters": { @@ -24,7 +33,7 @@ def wandb_config_merger(): wandb_config = wandb.config # Set desired config values - config.log_dir = str(Path(gflownet.__file__).parent / "sweeps" / sweep_config["name"] / "run_logs" / wandb.run.name) + config.log_dir = (f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}",) config.print_every = 100 config.validate_every = 1000 config.num_final_gen_steps = 1000 @@ -49,4 +58,16 @@ def wandb_config_merger(): if __name__ == "__main__": - wandb.sweep(sweep_config, entity="valencelabs", project="gflownet") + # if there are arguments, this is a wandb agent + if len(sweep_config["parameters"]) > 0: + wandb.init(entity="valencelabs", project="gflownet") + config = wandb_config_merger() + trial = seh_frag_moo.SEHMOOFragTrainer(config) + trial.run() + + # otherwise, initialize the sweep + else: + if os.path.exists(STORAGE_DIR): + raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.") + + wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT) diff --git a/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py b/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py deleted file mode 100644 index 4a294c1f..00000000 --- a/src/gflownet/hyperopt/wandb_demo/wandb_agent_main.py
+++ /dev/null @@ -1,10 +0,0 @@ -import wandb -from init_wandb_sweep import wandb_config_merger - -import gflownet.tasks.seh_frag_moo as seh_frag_moo - -if __name__ == "__main__": - wandb.init(entity="valencelabs", project="gflownet") - config = wandb_config_merger() - trial = seh_frag_moo.SEHMOOFragTrainer(config) - trial.run() From d7dc5b2e5134b81805951f144b3c540cbba3bd8e Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 27 Feb 2024 15:51:03 +0000 Subject: [PATCH 19/22] chore: tox --- src/gflownet/algo/config.py | 2 -- src/gflownet/tasks/qm9.py | 2 +- src/gflownet/tasks/qm9_moo.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 90bdf1ea..0ccf2e0e 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -116,8 +116,6 @@ class AlgoConfig: Do not take random actions after this number of steps valid_random_action_prob : float The probability of taking a random action during validation - valid_sample_cond_info : bool - Whether to sample conditioning information during validation (if False, expects a validation set of cond_info) sampling_tau : float The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta) """ diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 967beac8..0e934906 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -163,6 +163,7 @@ def setup(self): self.training_data.setup(self.task, self.ctx) self.test_data.setup(self.task, self.ctx) + def main(): """Example of how this model can be run.""" config = init_empty(Config()) @@ -180,4 +181,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index cb0e8277..b1dab870 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -15,7 +15,7 @@ from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer +from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics @@ -197,7 +197,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.sampling_tau = 0.95 # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) # sampling and set the offline ratio to 1 - cfg.algo.valid_sample_cond_info = False + cfg.cond.valid_sample_cond_info = False cfg.algo.valid_offline_ratio = 1 def setup_algo(self): From df5d799c5b9a22e92941f1842d7d9d651db55df8 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 27 Feb 2024 16:30:23 +0000 Subject: [PATCH 20/22] fix: minor in wandb_demo --- .../hyperopt/wandb_demo/init_wandb_sweep.py | 24 +++++++++---------- .../wandb_demo/launch_wandb_agents.sh | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py index 59b0211b..c8d7de54 100644 --- a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -1,4 +1,5 @@ import os +import sys import time import wandb @@ -8,7 +9,7 @@ TIME = time.strftime("%m-%d-%H-%M") ENTITY = "valencelabs" -PROJECT = "syngfn" +PROJECT = "gflownet" SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay" 
STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}" @@ -22,8 +23,8 @@ }, "method": "grid", "parameters": { - "config.algo.tb.Z_learning_rate": {"values": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]}, - "config.algo.tb.Z_lr_decay": {"values": [2_000, 10_000, 50_000, 250_000]}, + "config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]}, + "config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]}, }, } @@ -58,16 +59,15 @@ def wandb_config_merger(): if __name__ == "__main__": - # if there are arguments, this is a wandb agent - if len(sweep_config["parameters"]) > 0: - wandb.init(entity="valencelabs", project="gflownet") - config = wandb_config_merger() - trial = seh_frag_moo.SEHMOOFragTrainer(config) - trial.run() - - # otherwise, initialize the sweep - else: + # if there are no arguments, initialize the sweep, otherwise this is a wandb agent + if len(sys.argv) == 1: if os.path.exists(STORAGE_DIR): raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.") wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT) + + else: + wandb.init(entity=ENTITY, project=PROJECT) + config = wandb_config_merger() + trial = seh_frag_moo.SEHMOOFragTrainer(config) + trial.run() diff --git a/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh index 03526f57..c1d27990 100755 --- a/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh +++ b/src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh @@ -4,7 +4,7 @@ # Usage: sbatch launch_wandb_agent.sh #SBATCH --job-name=wandb_sweep_agent -#SBATCH --array=1-20 +#SBATCH --array=1-6 #SBATCH --time=23:59:00 #SBATCH --output=slurm_output_files/%x_%N_%A_%a.out #SBATCH --gpus=1 From 2596dc475eb26b9c62c83710fddf385181191a08 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 27 Feb 2024 16:44:45 +0000 Subject: [PATCH 21/22] fix: storage path --- src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py | 6 +++--- src/gflownet/tasks/seh_frag_moo.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py index c8d7de54..a48ba71d 100644 --- a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -4,7 +4,7 @@ import wandb -import gflownet.tasks.seh_frag_moo as seh_frag_moo +from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer from gflownet.config import Config, init_empty TIME = time.strftime("%m-%d-%H-%M") @@ -34,7 +34,7 @@ def wandb_config_merger(): wandb_config = wandb.config # Set desired config values - config.log_dir = (f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}",) + config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}" config.print_every = 100 config.validate_every = 1000 config.num_final_gen_steps = 1000 @@ -69,5 +69,5 @@ def wandb_config_merger(): else: wandb.init(entity=ENTITY, project=PROJECT) config = wandb_config_merger() - trial = seh_frag_moo.SEHMOOFragTrainer(config) + trial = SEHMOOFragTrainer(config) trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index b5b5a51a..1f8787d3 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -395,6 +395,8 @@ def main(): config.cond.weighted_prefs.preference_type = "dirichlet" config.cond.focus_region.focus_type = None config.replay.use = False + config.task.seh_moo.n_valid = 15 + config.task.seh_moo.n_valid_repeats = 2 trial = SEHMOOFragTrainer(config) trial.run() From
e17f70f1fb952b00afd45df2db02d5485bec5637 Mon Sep 17 00:00:00 2001 From: julienroyd Date: Tue, 27 Feb 2024 16:52:47 +0000 Subject: [PATCH 22/22] chore: tox --- src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py index a48ba71d..8927794a 100644 --- a/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py +++ b/src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py @@ -4,8 +4,8 @@ import wandb -from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer from gflownet.config import Config, init_empty +from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer TIME = time.strftime("%m-%d-%H-%M") ENTITY = "valencelabs"
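Taken together, patches 02, 06 and 09 settle on one configuration pattern: a task builds a sparse Config via init_empty, and GFNTrainer overlays it on its defaults with OmegaConf.merge, where MISSING fields do not override. A minimal sketch of that behaviour (the field values here are placeholders, not recommended settings):

from omegaconf import OmegaConf

from gflownet.config import Config, init_empty

# User-side config: every field starts as MISSING, so only explicit
# assignments survive the merge below.
user_cfg = init_empty(Config())
user_cfg.log_dir = "./logs/demo"  # placeholder path
user_cfg.algo.tb.Z_learning_rate = 1e-3

# Trainer-side defaults, as in GFNTrainer.__init__ after patch 09.
defaults = OmegaConf.structured(Config())
merged = OmegaConf.merge(defaults, user_cfg)

assert merged.algo.tb.Z_learning_rate == 1e-3  # explicit user value wins
assert merged.device == "cuda"  # untouched fields keep their defaults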