Add wandb-sweep example and clean-up use of configs #118

Merged on Feb 27, 2024 (23 commits)

Commits
307e585
feat: added support for wandb logging (single runs, not sweeps yet)
julienroyd Feb 13, 2024
e2afa44
feat: replaced hps provided as a dict by a Config() object in seh_fra…
julienroyd Feb 13, 2024
22e0a6c
feat: added config.desc
julienroyd Feb 13, 2024
4a90c1a
fix: added train/valid for wandb log
julienroyd Feb 13, 2024
187c970
fix: allow JSON serialization of Enum objects
julienroyd Feb 13, 2024
235f058
chore: tox
julienroyd Feb 13, 2024
98d2522
chore: replaced hps (dict) by Config() in all tasks. Moved qm9.py out…
julienroyd Feb 14, 2024
5c5a751
fix: changed default focus_region for frag_moo
julienroyd Feb 14, 2024
0a244f2
fix: added assert to prevent inadvertently manipulating a Config rath…
julienroyd Feb 14, 2024
472c16a
removed cfg.use_wandb and simply test if wandb has been initialised i…
julienroyd Feb 14, 2024
046edaf
chore: adding comment
julienroyd Feb 14, 2024
d64e19f
fix: typo
julienroyd Feb 14, 2024
d28d860
chore: added cfg.task.seh_moo.log_topk to de-clutter a bit
julienroyd Feb 14, 2024
d5d200e
fix: added wandb to dependencies
julienroyd Feb 14, 2024
39e71f8
minor: file name change for consistency
julienroyd Feb 14, 2024
fa1d7b3
chore: centralised self.cfg.overwrite_existing_exp in GFNTrainer() (r…
julienroyd Feb 16, 2024
c530a00
feat: added hyperopt/wandb_demo
julienroyd Feb 16, 2024
90228db
feat: removed wandb_agent_main.py to have the search and entrypoint d…
julienroyd Feb 27, 2024
5668b69
Merge branch 'trunk' into julien-wandb
julienroyd Feb 27, 2024
d7dc5b2
chore: tox
julienroyd Feb 27, 2024
df5d799
fix: minor in wandb_demo
julienroyd Feb 27, 2024
2596dc4
fix: storage path
julienroyd Feb 27, 2024
e17f70f
chore: tox
julienroyd Feb 27, 2024
5 changes: 1 addition & 4 deletions src/gflownet/algo/config.py
@@ -3,7 +3,7 @@
from typing import Optional


class TBVariant(Enum):
class TBVariant(int, Enum):
"""See algo.trajectory_balance.TrajectoryBalance for details."""

TB = 0
@@ -116,8 +116,6 @@ class AlgoConfig:
Do not take random actions after this number of steps
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Whether to sample conditioning information during validation (if False, expects a validation set of cond_info)
sampling_tau : float
The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
"""
@@ -133,7 +131,6 @@ class AlgoConfig:
train_random_action_prob: float = 0.0
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
tb: TBConfig = TBConfig()
moql: MOQLConfig = MOQLConfig()
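An aside on the `TBVariant(int, Enum)` change above (commit 187c970, "allow JSON serialization of Enum objects"): a plain `Enum` member is rejected by the standard `json` encoder, while an enum that also inherits from `int` serializes as its integer value. A minimal, self-contained sketch (class names are illustrative, not from the repo):

```python
import json
from enum import Enum


class PlainVariant(Enum):
    TB = 0


class IntVariant(int, Enum):
    TB = 0


# json.dumps(PlainVariant.TB) raises:
#   TypeError: Object of type PlainVariant is not JSON serializable
print(json.dumps(IntVariant.TB))  # prints 0, because each member is also an int
```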
25 changes: 23 additions & 2 deletions src/gflownet/config.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, fields, is_dataclass
from typing import Optional

from omegaconf import MISSING
@@ -50,6 +50,8 @@ class Config:

Attributes
----------
desc : str
A description of the experiment
log_dir : str
The directory where to store logs, checkpoints, and samples.
device : str
@@ -82,6 +84,7 @@ class Config:
Whether to overwrite the contents of the log_dir if it already exists
"""

desc: str = "noDesc"
log_dir: str = MISSING
device: str = "cuda"
seed: int = 0
@@ -96,10 +99,28 @@
hostname: Optional[str] = None
pickle_mp_messages: bool = False
git_hash: Optional[str] = None
overwrite_existing_exp: bool = True
overwrite_existing_exp: bool = False
algo: AlgoConfig = AlgoConfig()
model: ModelConfig = ModelConfig()
opt: OptimizerConfig = OptimizerConfig()
replay: ReplayConfig = ReplayConfig()
task: TasksConfig = TasksConfig()
cond: ConditionalsConfig = ConditionalsConfig()


def init_empty(cfg):
"""
Initialize a dataclass instance with all fields set to MISSING,
including nested dataclasses.

This is meant to be used on the user side (tasks) to provide
configuration through the Config class while overriding only
the fields that the user has explicitly set.
"""
for f in fields(cfg):
if is_dataclass(f.type):
setattr(cfg, f.name, init_empty(f.type()))
else:
setattr(cfg, f.name, MISSING)

return cfg
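For context, a minimal sketch of how a task script is expected to use `init_empty` (the same pattern appears in the updated `main()` functions later in this diff; the field values here are illustrative):

```python
from gflownet.config import Config, init_empty

# Start from a Config where every field is MISSING and set only what you need;
# fields left as MISSING fall back to the defaults when the trainer merges configs.
config = init_empty(Config())
config.log_dir = "./logs/my_debug_run"  # illustrative path
config.device = "cpu"
config.num_training_steps = 1_000
config.algo.tb.do_parameterize_p_b = True
```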
6 changes: 6 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/README.md
@@ -0,0 +1,6 @@
Everything is contained in one file; `init_wandb_sweep.py` both defines the search space of the sweep and is the entrypoint of wandb agents.

To launch the search:
1. `python init_wandb_sweep.py` to initialize the sweep
2. `sbatch launch_wandb_agents.sh <SWEEP_ID>` to schedule a Slurm job array that will launch the wandb agents.
The number of jobs in the sbatch file should reflect the size of the hyperparameter space being swept.
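Note (not stated in the README itself): `wandb.sweep` returns the sweep ID in addition to printing it, so step 1 can also be scripted and the ID handed straight to step 2. A hedged sketch with an illustrative sweep definition; the real search space lives in `init_wandb_sweep.py` below:

```python
import wandb

# Illustrative sweep definition; the actual one is defined in init_wandb_sweep.py.
sweep_config = {
    "name": "demo-sweep",
    "method": "grid",
    "parameters": {"config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3]}},
}

# wandb.sweep returns the sweep ID (it is also printed to the console);
# this is the value passed as <SWEEP_ID> in step 2.
sweep_id = wandb.sweep(sweep_config, entity="valencelabs", project="gflownet")
print(sweep_id)
```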
73 changes: 73 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/init_wandb_sweep.py
@@ -0,0 +1,73 @@
import os
import sys
import time

import wandb

from gflownet.config import Config, init_empty
from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer

TIME = time.strftime("%m-%d-%H-%M")
ENTITY = "valencelabs"
PROJECT = "gflownet"
SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay"
STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}"


# Define the search space of the sweep
sweep_config = {
"name": SWEEP_NAME,
"program": "init_wandb_sweep.py",
"controller": {
"type": "cloud",
},
"method": "grid",
"parameters": {
"config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]},
"config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]},
},
}


def wandb_config_merger():
config = init_empty(Config())
wandb_config = wandb.config

# Set desired config values
config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}"
config.print_every = 100
config.validate_every = 1000
config.num_final_gen_steps = 1000
config.num_training_steps = 40_000
config.pickle_mp_messages = True
config.overwrite_existing_exp = False
config.algo.sampling_tau = 0.95
config.algo.train_random_action_prob = 0.01
config.algo.tb.Z_learning_rate = 1e-3
config.task.seh_moo.objectives = ["seh", "qed"]
config.cond.temperature.sample_dist = "constant"
config.cond.temperature.dist_params = [60.0]
config.cond.weighted_prefs.preference_type = "dirichlet"
config.cond.focus_region.focus_type = None
config.replay.use = False

# Merge the wandb sweep config with the nested config from gflownet
config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"]
config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"]

return config


if __name__ == "__main__":
# if there are no arguments, initialize the sweep; otherwise this is a wandb agent
if len(sys.argv) == 1:
if os.path.exists(STORAGE_DIR):
raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.")

wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT)

else:
wandb.init(entity=ENTITY, project=PROJECT)
config = wandb_config_merger()
trial = SEHMOOFragTrainer(config)
trial.run()
19 changes: 19 additions & 0 deletions src/gflownet/hyperopt/wandb_demo/launch_wandb_agents.sh
@@ -0,0 +1,19 @@
#!/bin/bash

# Purpose: Script to allocate a node and run a wandb sweep agent on it
# Usage: sbatch launch_wandb_agents.sh <SWEEP_ID>

#SBATCH --job-name=wandb_sweep_agent
#SBATCH --array=1-6
#SBATCH --time=23:59:00
#SBATCH --output=slurm_output_files/%x_%N_%A_%a.out
#SBATCH --gpus=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=16GB
#SBATCH --partition compute

source activate gfn-py39-torch113
echo "Using environment={$CONDA_DEFAULT_ENV}"

# launch wandb agent
wandb agent --count 1 --entity valencelabs --project gflownet $1
2 changes: 1 addition & 1 deletion src/gflownet/models/config.py
@@ -9,7 +9,7 @@ class GraphTransformerConfig:
num_mlp_layers: int = 0


class SeqPosEnc(Enum):
class SeqPosEnc(int, Enum):
Pos = 0
Rotary = 1

12 changes: 6 additions & 6 deletions src/gflownet/online_trainer.py
@@ -105,13 +105,13 @@ def setup(self):
git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7]
self.cfg.git_hash = git_hash

yaml = OmegaConf.to_yaml(self.cfg)
os.makedirs(self.cfg.log_dir, exist_ok=True)
if self.print_hps:
yaml_cfg = OmegaConf.to_yaml(self.cfg)
if self.print_config:
print("\n\nHyperparameters:\n")
print(yaml)
with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f:
f.write(yaml)
print(yaml_cfg)
os.makedirs(self.cfg.log_dir, exist_ok=True)
with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f:
f.write(yaml_cfg)

def step(self, loss: Tensor):
loss.backward()
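Since the run configuration is now written to `config.yaml` in the log directory (instead of `hps.yaml`), it can be read back with OmegaConf after a run; a minimal sketch, with an illustrative path:

```python
from omegaconf import OmegaConf

# Load the configuration that was saved during trainer setup for a finished run.
cfg = OmegaConf.load("./logs/debug_run_seh_frag_pb/config.yaml")
print(cfg.algo.sampling_tau)
```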
1 change: 1 addition & 0 deletions src/gflownet/tasks/config.py
@@ -26,6 +26,7 @@ class SEHMOOTaskConfig:
n_valid: int = 15
n_valid_repeats: int = 128
objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
log_topk: bool = False
online_pareto_front: bool = True


24 changes: 11 additions & 13 deletions src/gflownet/tasks/make_rings.py
@@ -1,4 +1,3 @@
import os
import socket
from typing import Dict, List, Tuple, Union

@@ -8,7 +7,7 @@
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor

from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
@@ -72,17 +71,16 @@ def setup_env_context(self):


def main():
hps = {
"log_dir": "./logs/debug_run_mr4",
"device": "cuda",
"num_training_steps": 10_000,
"num_workers": 8,
"algo": {"tb": {"do_parameterize_p_b": True}},
}
os.makedirs(hps["log_dir"], exist_ok=True)

trial = MakeRingsTrainer(hps)
trial.print_every = 1
"""Example of how this model can be run."""
config = init_empty(Config())
config.print_every = 1
config.log_dir = "./logs/debug_run_mr4"
config.device = "cuda"
config.num_training_steps = 10_000
config.num_workers = 8
config.algo.tb.do_parameterize_p_b = True

trial = MakeRingsTrainer(config)
trial.run()


21 changes: 20 additions & 1 deletion src/gflownet/tasks/qm9/qm9.py → src/gflownet/tasks/qm9.py
@@ -9,7 +9,7 @@
from torch.utils.data import Dataset

import gflownet.models.mxmnet as mxmnet
from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.data.qm9 import QM9Dataset
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
@@ -162,3 +162,22 @@ def setup(self):
super().setup()
self.training_data.setup(self.task, self.ctx)
self.test_data.setup(self.task, self.ctx)


def main():
"""Example of how this model can be run."""
config = init_empty(Config())
config.num_workers = 0
config.num_training_steps = 100000
config.validate_every = 100
config.log_dir = "./logs/debug_qm9"
config.opt.lr_decay = 10000
config.task.qm9.h5_path = "/rxrx/data/chem/qm9/qm9.h5"
config.task.qm9.model_path = "/rxrx/data/chem/qm9/mxmnet_gap_model.pt"

trial = QM9GapTrainer(config)
trial.run()


if __name__ == "__main__":
main()
Empty file.
@@ -15,7 +15,7 @@
from gflownet.config import Config
from gflownet.data.qm9 import QM9Dataset
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer
from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer
from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks
from gflownet.trainer import FlatRewards, RewardScalar
from gflownet.utils import metrics
@@ -197,7 +197,7 @@ def set_default_hps(self, cfg: Config):
cfg.algo.sampling_tau = 0.95
# We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info)
# sampling and set the offline ratio to 1
cfg.algo.valid_sample_cond_info = False
cfg.cond.valid_sample_cond_info = False
cfg.algo.valid_offline_ratio = 1

def setup_algo(self):
46 changes: 16 additions & 30 deletions src/gflownet/tasks/seh_frag.py
@@ -1,5 +1,3 @@
import os
import shutil
import socket
from typing import Callable, Dict, List, Tuple, Union

@@ -13,7 +11,7 @@
from torch.utils.data import Dataset
from torch_geometric.data import Data

from gflownet.config import Config
from gflownet.config import Config, init_empty
from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph
from gflownet.models import bengio2021flow
from gflownet.online_trainer import StandardOnlineTrainer
@@ -200,33 +198,21 @@ def setup(self):


def main():
"""Example of how this trainer can be run"""
hps = {
"log_dir": "./logs/debug_run_seh_frag_pb",
"device": "cuda" if torch.cuda.is_available() else "cpu",
"overwrite_existing_exp": True,
"num_training_steps": 10_000,
"num_workers": 8,
"opt": {
"lr_decay": 20000,
},
"algo": {"sampling_tau": 0.99, "offline_ratio": 0.0},
"cond": {
"temperature": {
"sample_dist": "uniform",
"dist_params": [0, 64.0],
}
},
}
if os.path.exists(hps["log_dir"]):
if hps["overwrite_existing_exp"]:
shutil.rmtree(hps["log_dir"])
else:
raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.")
os.makedirs(hps["log_dir"])

trial = SEHFragTrainer(hps)
trial.print_every = 1
"""Example of how this model can be run."""
config = init_empty(Config())
config.print_every = 1
config.log_dir = "./logs/debug_run_seh_frag_pb"
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.overwrite_existing_exp = True
config.num_training_steps = 10_000
config.num_workers = 8
config.opt.lr_decay = 20_000
config.algo.sampling_tau = 0.99
config.algo.offline_ratio = 0.0
config.cond.temperature.sample_dist = "uniform"
config.cond.temperature.dist_params = [0, 64.0]

trial = SEHFragTrainer(config)
trial.run()


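Tying the PR together: per commit 472c16a ("removed cfg.use_wandb and simply test if wandb has been initialised…"), wandb logging for a single run is turned on simply by initialising wandb before the trainer runs. A hedged sketch building on the `main()` above (project name and log path are placeholders):

```python
import torch
import wandb

from gflownet.config import Config, init_empty
from gflownet.tasks.seh_frag import SEHFragTrainer

config = init_empty(Config())
config.log_dir = "./logs/seh_frag_wandb_single_run"  # placeholder path
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.num_training_steps = 10_000

# Initialising wandb before training is what enables wandb logging in the trainer.
wandb.init(project="gflownet")
SEHFragTrainer(config).run()
```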