Skip to content

Commit

Permalink
fix: made focus_dir and preferences accessible at the batch level
Browse files Browse the repository at this point in the history
  • Loading branch information
julienroyd committed Apr 4, 2024
1 parent d7d6997 commit ad0db6c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/gflownet/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ def create_batch(self, trajs, batch_info):
batch.num_online = sum(t.get("is_online", 0) for t in trajs)
batch.num_offline = len(trajs) - batch.num_online
batch.extra_info = batch_info
if "preferences" in trajs[0]:
batch.preferences = torch.stack([t["preferences"] for t in trajs])
if "focus_dir" in trajs[0]:
batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])
if "preferences" in trajs[0]['cond_info'].keys():
batch.preferences = torch.stack([t['cond_info']["preferences"] for t in trajs])
if "focus_dir" in trajs[0]['cond_info'].keys():
batch.focus_dir = torch.stack([t['cond_info']["focus_dir"] for t in trajs])

if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n:
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
Expand Down

0 comments on commit ad0db6c

Please sign in to comment.