Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backward sampling #110

Merged
merged 10 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/grid_cond_gfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def sample_many(self, mbsize):
return log_ratio

def learn_from(self, it, batch):
if type(batch) is list:
if isinstance(batch, list):
bengioe marked this conversation as resolved.
Show resolved Hide resolved
log_ratio = torch.stack(batch, 0)
else:
log_ratio = batch
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ requires-python = ">=3.8,<3.10"
dynamic = ["version"]
dependencies = [
"torch==1.13.1",
"torch-geometric",
"torch-scatter",
"torch-sparse",
"torch-cluster",
"torch-geometric==2.3.1", # Pinning until we adapt the code to newer versions
"torch-scatter==2.1.1",
"torch-sparse==0.6.17",
"torch-cluster==1.6.1",
"rdkit",
"tables",
"scipy",
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class TBConfig:
variant: TBVariant = TBVariant.TB
do_correct_idempotent: bool = False
do_parameterize_p_b: bool = False
do_sample_p_b: bool = False
do_length_normalize: bool = False
subtb_max_len: int = 128
Z_learning_rate: float = 1e-4
Expand Down
118 changes: 116 additions & 2 deletions src/gflownet/algo/graph_sampling.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
import copy
from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from gflownet.envs.graph_building_env import GraphAction, GraphActionType
from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType
from gflownet.models.graph_transformer import GraphTransformerGFN


def relabel(g: Graph, ga: GraphAction):
"""Relabel the nodes for g to 0-N, and the graph action ga applied to g.
This is necessary because torch_geometric and EnvironmentContext classes expect nodes to be
labeled 0-N, whereas GraphBuildingEnv.parent can return parents with e.g. a removed node that
creates a gap in 0-N, leading to a faulty encoding of the graph.
"""
rmap = dict(zip(g.nodes, range(len(g.nodes))))
if not len(g) and ga.action == GraphActionType.AddNode:
rmap[0] = 0 # AddNode can add to the empty graph, the source is still 0
g = g.relabel_nodes(rmap)
if ga.source is not None:
ga.source = rmap[ga.source]
if ga.target is not None:
ga.target = rmap[ga.target]
return g, ga


class GraphSampler:
Expand Down Expand Up @@ -185,3 +203,99 @@ def not_done(lst):
data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop)))
data[i]["is_sink"].append(1)
return data

def sample_backward_from_graphs(
self,
graphs: List[Graph],
model: Optional[nn.Module],
cond_info: Tensor,
dev: torch.device,
random_action_prob: float = 0.0,
):
"""Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform distribution
over legal actions.

Parameters
----------
graphs: List[Graph]
List of Graph endpoints
model: nn.Module
Model whose forward() method returns GraphActionCategorical instances
cond_info: Tensor
Conditional information of each trajectory, shape (n, n_info)
dev: torch.device
Device on which data is manipulated
random_action_prob: float
Probability of taking a random action (only used if model parameterizes P_B)

"""
n = len(graphs)
done = [False] * n
data = [
{
"traj": [(graphs[i], GraphAction(GraphActionType.Stop))],
"is_valid": True,
"is_sink": [1],
"bck_a": [GraphAction(GraphActionType.Stop)],
"bck_logprobs": [0.0],
"result": graphs[i],
}
for i in range(n)
]

def not_done(lst):
return [e for i, e in enumerate(lst) if not done[i]]

if random_action_prob > 0:
raise NotImplementedError("Random action not implemented for backward sampling")

while sum(done) < n:
torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))]
not_done_mask = torch.tensor(done, device=dev).logical_not()
if model is not None:
_, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
else:
gbatch = self.ctx.collate(torch_graphs)
action_types = self.ctx.bck_action_type_order
masks = [getattr(gbatch, i.mask_name) for i in action_types]
bck_cat = GraphActionCategorical(
gbatch,
logits=[m * 1e6 for m in masks],
keys=[
# TODO: This is not very clean, could probably abstract this away somehow
GraphTransformerGFN._graph_part_to_key[GraphTransformerGFN._action_type_to_graph_part[t]]
for t in action_types
],
masks=masks,
types=action_types,
)
bck_actions = bck_cat.sample()
graph_bck_actions = [
self.ctx.aidx_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions)
]
bck_logprobs = bck_cat.log_prob(bck_actions)

for i, j in zip(not_done(range(n)), range(n)):
if not done[i]:
g = graphs[i]
b_a = graph_bck_actions[j]
gp = self.env.step(g, b_a)
f_a = self.env.reverse(g, b_a)
graphs[i], f_a = relabel(gp, f_a)
data[i]["traj"].append((graphs[i], f_a))
data[i]["bck_a"].append(b_a)
data[i]["is_sink"].append(0)
data[i]["bck_logprobs"].append(bck_logprobs[j].item())
if len(graphs[i]) == 0:
done[i] = True

for i in range(n):
# See comments in sample_from_model
data[i]["traj"] = data[i]["traj"][::-1]
data[i]["bck_a"] = [GraphAction(GraphActionType.Stop)] + data[i]["bck_a"][::-1]
data[i]["is_sink"] = data[i]["is_sink"][::-1]
data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1)
if self.pad_with_terminal_state:
data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop)))
data[i]["is_sink"].append(1)
return data
27 changes: 23 additions & 4 deletions src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -148,8 +148,8 @@ def create_training_data_from_own_samples(
----------
model: TrajectoryBalanceModel
The model being sampled
graphs: List[Graph]
List of N Graph endpoints
n: int
Number of trajectories to sample
cond_info: torch.tensor
Conditional information, shape (N, n_info)
random_action_prob: float
Expand All @@ -174,19 +174,38 @@ def create_training_data_from_own_samples(
data[i]["logZ"] = logZ_pred[i].item()
return data

def create_training_data_from_graphs(self, graphs):
def create_training_data_from_graphs(
self,
graphs,
model: Optional[TrajectoryBalanceModel] = None,
cond_info: Optional[Tensor] = None,
random_action_prob: Optional[float] = None,
):
"""Generate trajectories from known endpoints

Parameters
----------
graphs: List[Graph]
List of Graph endpoints
model: TrajectoryBalanceModel
The model being sampled
cond_info: torch.tensor
Conditional information, shape (N, n_info)
random_action_prob: float
Probability of taking a random action

Returns
-------
trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}]
A list of trajectories.
"""
if self.cfg.do_sample_p_b:
assert model is not None and cond_info is not None and random_action_prob is not None
dev = self.ctx.device
cond_info = cond_info.to(dev)
return self.graph_sampler.sample_backward_from_graphs(
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob
)
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
for traj in trajs:
n_back = [
Expand Down
11 changes: 7 additions & 4 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ def __iter__(self):
)

# Sample some dataset data
mols, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], [])
graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], [])
flat_rewards = (
list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else []
)
graphs = [self.ctx.mol_to_graph(m) for m in mols]
trajs = self.algo.create_training_data_from_graphs(graphs)
trajs = self.algo.create_training_data_from_graphs(
graphs, self.model, cond_info["encoding"][:num_offline], 0
)

else: # If we're not sampling the conditionals, then the idcs refer to listed preferences
num_online = num_offline
Expand Down Expand Up @@ -411,7 +412,9 @@ def _make_results_table(self, types, names):
cur.close()

def insert_many(self, rows, column_names):
assert all([type(x) is str or not isinstance(x, Iterable) for x in rows[0]]), "rows must only contain scalars"
assert all(
[isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]]
), "rows must only contain scalars"
if not self._has_results_table:
self._make_results_table([type(i) for i in rows[0]], column_names)
cur = self.db.cursor()
Expand Down
Loading