Merge branch 'trunk' into add_small_graph_task
bengioe committed Aug 4, 2023
2 parents 8969470 + ec857a5 commit 9e489ab
Showing 23 changed files with 516 additions and 249 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1,2 +1,2 @@
MAJOR="0"
MINOR="0"
MINOR="1"
18 changes: 18 additions & 0 deletions docs/implementation_notes.md
@@ -16,3 +16,21 @@ We separate experiment concerns in four categories:
- The Trainer class is responsible for instantiating everything, and running the training & testing loop

Typically one would set up a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.


## Graphs

This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations.
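As a rough illustration of the idea (a hand-rolled sketch, not the library's actual `GraphBuildingEnvContext` conversion code), an undirected `networkx` graph can be encoded as a `torch_geometric` `Data` object by duplicating each edge in both directions, as described in the notes below:

```python
# Illustrative sketch only; the library's contexts do this conversion with
# task-specific node/edge features and action masks.
import networkx as nx
import torch
from torch_geometric.data import Data

g = nx.Graph()
g.add_nodes_from([0, 1, 2])
g.add_edges_from([(0, 1), (1, 2)])

x = torch.zeros((g.number_of_nodes(), 4))  # placeholder node features
# Each undirected edge (u, v) is stored contiguously as (u, v) then (v, u)
edge_index = torch.tensor(
    [pair for u, v in g.edges for pair in ((u, v), (v, u))], dtype=torch.long
).t()
data = Data(x=x, edge_index=edge_index)  # edge_index has shape (2, 2 * n_edges)
```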

Some notes:
- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs.
- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.


### Graph policies & graph action categoricals

The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits across a minibatch.

Consider for example the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and the other to edges. An efficient way to produce logits for these actions is to take the node/edge embeddings and project them (e.g. via an MLP) to `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensors respectively. We thus obtain a list of tensors representing the logits of different actions, but the logits of different graphs in the minibatch are mixed together within each tensor, so one cannot simply apply a `softmax` to it.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
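As a rough sketch of the core idea behind that class (a per-graph normalization over logits that are concatenated across the minibatch; this is not the actual `GraphActionCategorical` implementation, which also handles several logit tensors at once), assuming a single node-level logit tensor and a per-node graph index:

```python
import torch


def per_graph_log_softmax(logits: torch.Tensor, batch: torch.Tensor, num_graphs: int):
    """logits: (n_nodes, n_actions), concatenated over graphs; batch: (n_nodes,) graph index."""
    # Per-graph maximum over all rows and action columns, for numerical stability
    m = torch.full((num_graphs, logits.shape[1]), float("-inf"))
    m = m.scatter_reduce(0, batch[:, None].expand_as(logits), logits, reduce="amax").amax(1)
    shifted = logits - m[batch, None]
    # Per-graph partition function: sum of exp over that graph's rows and columns
    z = torch.zeros(num_graphs).index_add_(0, batch, shifted.exp().sum(1))
    return shifted - z.log()[batch, None]
```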
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -75,7 +75,7 @@ dependencies = [
"botorch",
"pyro-ppl",
"gpytorch",
"omegaconf",
"omegaconf>=2.3",
]

[project.optional-dependencies]
3 changes: 3 additions & 0 deletions src/gflownet/algo/config.py
@@ -91,6 +91,8 @@ class AlgoConfig:
offline_ratio: float
The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
`self.sampling_model`
valid_offline_ratio: float
Same as `offline_ratio`, but for validation; the offline samples are then drawn from `self.test_data`.
train_random_action_prob : float
The probability of taking a random action during training
valid_random_action_prob : float
@@ -108,6 +110,7 @@ class AlgoConfig:
max_edges: int = 128
illegal_action_logreward: float = -100
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
train_random_action_prob: float = 0.0
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
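As a small worked example of how `offline_ratio` splits a global batch between dataset samples and model samples (the ceil/floor split mirrors what `SamplingIterator` does later in this diff; the numbers are made up):

```python
import numpy as np

global_batch_size, offline_ratio = 64, 0.5  # example values only
offline_batch_size = int(np.ceil(global_batch_size * offline_ratio))        # from training_data
online_batch_size = int(np.floor(global_batch_size * (1 - offline_ratio)))  # from sampling_model
assert offline_batch_size + online_batch_size == global_batch_size
```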
5 changes: 4 additions & 1 deletion src/gflownet/algo/trajectory_balance.py
@@ -31,7 +31,10 @@ def logZ(self, cond_info: Tensor) -> Tensor:


class TrajectoryBalance(GFNAlgorithm):
""" """
"""TB implementation, see
"Trajectory Balance: Improved Credit Assignment in GFlowNets Nikolay Malkin, Moksh Jain,
Emmanuel Bengio, Chen Sun, Yoshua Bengio"
https://arxiv.org/abs/2201.13259"""

def __init__(
self,
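For context, the objective introduced in the cited paper penalizes, per trajectory, the squared mismatch between log Z plus the forward trajectory log-probability and the log-reward plus the backward trajectory log-probability; a minimal sketch (not this class's actual implementation, which works on batched trajectories):

```python
import torch


def tb_loss(log_z, fwd_logprobs, bck_logprobs, log_reward):
    # (log Z + sum_t log P_F(a_t|s_t) - log R(x) - sum_t log P_B(a_t|s_{t+1}))^2, one trajectory
    return (log_z + fwd_logprobs.sum() - log_reward - bck_logprobs.sum()).pow(2)
```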
6 changes: 6 additions & 0 deletions src/gflownet/config.py
@@ -52,12 +52,16 @@ class Config:
----------
log_dir : str
The directory where to store logs, checkpoints, and samples.
device : str
The device to use for training (either "cpu" or "cuda[:<device_id>]")
seed : int
The random seed
validate_every : int
The number of training steps after which to validate the model
checkpoint_every : Optional[int]
The number of training steps after which to checkpoint the model
print_every : int
The number of training steps after which to print the training loss
start_at_step : int
The training step to start at (default: 0)
num_final_gen_steps : Optional[int]
@@ -78,9 +82,11 @@ class Config:

log_dir: str = MISSING
log_sampled_data: bool = True
device: str = "cuda"
seed: int = 0
validate_every: int = 1000
checkpoint_every: Optional[int] = None
print_every: int = 100
start_at_step: int = 0
num_final_gen_steps: Optional[int] = None
num_training_steps: int = 10_000
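A minimal sketch of building such a config with OmegaConf structured configs and overriding a few fields (assuming `Config` is importable as `gflownet.config.Config`, as this diff suggests; this is an illustration, not a documented entry point):

```python
from omegaconf import OmegaConf

from gflownet.config import Config

cfg = OmegaConf.structured(Config)
cfg = OmegaConf.merge(cfg, OmegaConf.create({
    "log_dir": "./logs/debug_run",  # required: log_dir defaults to MISSING
    "device": "cpu",                # default is "cuda"
    "print_every": 10,
    "validate_every": 100,
}))
```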
22 changes: 21 additions & 1 deletion src/gflownet/data/qm9.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
from torch.utils.data import Dataset


@@ -39,4 +40,23 @@ def __len__(self):
return len(self.idcs)

def __getitem__(self, idx):
return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]])
return (
Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]),
torch.tensor([self.df[self.target][self.idcs[idx]]]).float(),
)


def convert_h5():
# File obtained from
# https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
# (from http://quantum-machine.org/datasets/)
f = tarfile.TarFile("qm9.xyz.tar", "r")
labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
all_mols = []
for pt in f:
pt = f.extractfile(pt) # type: ignore
data = pt.read().decode().splitlines() # type: ignore
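# data[1] holds 'gdb <index>' followed by the 15 scalar properties; data[-2] holds the SMILES strings, of which only the first is kept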
all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
store = pd.HDFStore("qm9.h5", "w")
store["df"] = df
69 changes: 59 additions & 10 deletions src/gflownet/data/sampling_iterator.py
@@ -14,6 +14,7 @@

from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphActionCategorical


class SamplingIterator(IterableDataset):
@@ -30,11 +31,12 @@ def __init__(
self,
dataset: Dataset,
model: nn.Module,
cfg: Config,
ctx,
algo,
task,
device,
batch_size: int = 1,
illegal_action_logreward: float = -50,
ratio: float = 0.5,
stream: bool = True,
replay_buffer: ReplayBuffer = None,
@@ -51,14 +53,21 @@ def __init__(
model: nn.Module
The model we sample from (must be on CUDA already or share_memory() must be called so that
parameters are synchronized between each worker)
ctx:
The context for the environment, e.g. a MolBuildingEnvContext instance
algo:
The training algorithm, e.g. a TrajectoryBalance instance
task: GFNTask
A Task instance, e.g. a MakeRingsTask instance
device: torch.device
The device the model is on
replay_buffer: ReplayBuffer
The replay buffer for training on past data
batch_size: int
The number of trajectories; each trajectory comprises many graphs, so this is
_not_ the batch size in terms of the number of graphs (that will depend on the task)
algo:
The training algorithm, e.g. a TrajectoryBalance instance
task: ConditionalTask
illegal_action_logreward: float
The logreward for invalid trajectories
ratio: float
The ratio of offline trajectories in the batch.
stream: bool
@@ -69,13 +78,16 @@
sample_cond_info: bool
If True (default), then the dataset is a dataset of points used in offline training.
If False, then the dataset is a dataset of preferences (e.g. used to validate the model)
random_action_prob: float
The probability of taking a random action, passed to the graph sampler
init_train_iter: int
The initial training iteration, incremented and passed to task.sample_conditional_information
"""
self.cfg = cfg
self.data = dataset
self.model = model
self.replay_buffer = replay_buffer
self.batch_size = self.cfg.algo.global_batch_size
self.batch_size = batch_size
self.illegal_action_logreward = illegal_action_logreward
self.offline_batch_size = int(np.ceil(self.batch_size * ratio))
self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio)))
self.ratio = ratio
@@ -89,6 +101,8 @@ def __init__(
self.random_action_prob = random_action_prob
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging
self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag

# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
@@ -100,7 +114,10 @@ def __init__(
# don't want to initialize per-worker things just yet, such as the log the worker writes
# to. This must be done in __iter__, which is called by the DataLoader once this instance
# has been copied into a new python process.
self.log_dir = log_dir if cfg.log_sampled_data else None
import warnings

warnings.warn("Fix dependency on cfg.log_sampled_data")
self.log_dir = log_dir # if cfg.log_sampled_data else None
self.log = SQLiteLog()
self.log_hooks: List[Callable] = []
# TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db
@@ -122,6 +139,9 @@ def _idx_iterator(self):
if n == 0:
yield np.arange(0, 0)
return
assert (
self.offline_batch_size > 0
), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)"
if worker_info is None: # no multi-processing
start, end, wid = 0, n, -1
else: # split the data into chunks (per-worker)
@@ -232,9 +252,10 @@ def iterator(self):
# Compute scalar rewards from conditional information & flat rewards
flat_rewards = torch.stack(flat_rewards)
log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards)
log_rewards[torch.logical_not(is_valid)] = self.cfg.algo.illegal_action_logreward
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

# Computes some metrics
extra_info = {}
if not self.sample_cond_info:
# If we're using a dataset of preferences, the user may want to know the id of the preference
for i, j in zip(trajs, idcs):
@@ -304,7 +325,7 @@ def iterator(self):
cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards(
cond_info, log_rewards, flat_rewards, hindsight_idxs
)
log_rewards[torch.logical_not(is_valid)] = self.cfg.algo.illegal_action_logreward
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

# Construct batch
batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards)
@@ -317,9 +338,37 @@ def iterator(self):
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

# Only activate for debugging your environment or dataset (e.g. the dataset could be
# generating trajectories with illegal actions)
if self.do_validate_batch:
self.validate_batch(batch, trajs)

self.train_it += worker_info.num_workers if worker_info is not None else 1
yield batch

def validate_batch(self, batch, trajs):
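# Check that no trajectory in the batch takes an action that the model's masks forbid;
# both forward and (when present) backward actions are verified, and the offending
# step is reported if one is found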
for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
[(batch.bck_actions, self.ctx.bck_action_type_order)]
if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order")
else []
):
mask_cat = GraphActionCategorical(
batch,
[self.model._action_type_to_mask(t, batch) for t in atypes],
[self.model._action_type_to_key[t] for t in atypes],
[None for _ in atypes],
)
masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits)
num_trajs = len(trajs)
batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens)
first_graph_idx = torch.zeros_like(batch.traj_lens)
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
if masked_action_is_used.sum() != 0:
invalid_idx = masked_action_is_used.argmax().item()
traj_idx = batch_idx[invalid_idx].item()
timestep = invalid_idx - first_graph_idx[traj_idx].item()
raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep])

def log_generated(self, trajs, rewards, flat_rewards, cond_info):
if self.log_molecule_smis:
mols = [
4 changes: 2 additions & 2 deletions src/gflownet/envs/graph_building_env.py
@@ -271,9 +271,9 @@ def add_parent(a, new_g):
GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]),
new_g,
)
if len(g.nodes) == 1:
if len(g.nodes) == 1 and len(g.nodes[i]) == 1:
# The final node is degree 0, need this special case to remove it
# and end up with S0, the empty graph root
# and end up with S0, the empty graph root (but only if it has no attrs except 'v')
add_parent(
GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]),
graph_without_node(g, i),