Skip to content

Commit

Permalink
num rollouts per env for eval
Browse files: browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Dec 5, 2024
1 parent 9b18617 commit 9aeb714
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions ldp/alg/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def _run_eval_loop(
max_rollout_steps: int | None,
callbacks: Sequence[Callback],
shuffle: bool = False,
num_rollouts_per_env: int = 1,
) -> None:
await asyncio.gather(*[callback.before_eval_loop() for callback in callbacks])

Expand All @@ -48,16 +49,20 @@ async def _run_eval_loop(
# We use pbar.n as a counter for the number of eval iterations
while pbar.n < num_iterations:
for batch in dataset.iter_batches(batch_size, shuffle=shuffle):
trajectories = await rollout_manager.sample_trajectories(
environments=batch, max_steps=max_rollout_steps
)
all_trajectories: list[Trajectory] = []

for _ in range(num_rollouts_per_env):
trajectories = await rollout_manager.sample_trajectories(
environments=batch, max_steps=max_rollout_steps
)
all_trajectories.extend(trajectories)

# Close the environment after we have sampled from it,
# in case it needs to tear down resources.
await _close_envs(batch, rollout_manager.catch_env_failures)
# Close the environment after we have sampled from it,
# in case it needs to tear down resources.
await _close_envs(batch, rollout_manager.catch_env_failures)

await asyncio.gather(*[
callback.after_eval_step(trajectories) for callback in callbacks
callback.after_eval_step(all_trajectories) for callback in callbacks
])
pbar.update()

Expand All @@ -79,6 +84,7 @@ class EvaluatorConfig(BaseModel):
"If 0, will not run the eval loop. "
),
)
num_rollouts_per_env: int = 1
max_rollout_steps: int | None = None
catch_agent_failures: bool = True
catch_env_failures: bool = True
Expand Down Expand Up @@ -131,6 +137,7 @@ async def run(self, **kwargs) -> None:
"num_iterations": self.config.num_eval_iterations,
"max_rollout_steps": self.config.max_rollout_steps,
"shuffle": self.config.shuffle,
"num_rollouts_per_env": self.config.num_rollouts_per_env,
}
eval_kwargs |= kwargs
await _run_eval_loop(
Expand Down

0 comments on commit 9aeb714

Please sign in to comment.