From 9aeb714610526ceed8474274157f6496a8ed56c1 Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Thu, 5 Dec 2024 21:52:10 +0000 Subject: [PATCH] num rollouts per env for eval --- ldp/alg/runners.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ldp/alg/runners.py b/ldp/alg/runners.py index 2363c47..c078519 100644 --- a/ldp/alg/runners.py +++ b/ldp/alg/runners.py @@ -28,6 +28,7 @@ async def _run_eval_loop( max_rollout_steps: int | None, callbacks: Sequence[Callback], shuffle: bool = False, + num_rollouts_per_env: int = 1, ) -> None: await asyncio.gather(*[callback.before_eval_loop() for callback in callbacks]) @@ -48,16 +49,20 @@ async def _run_eval_loop( # We use pbar.n as a counter for the number of training steps while pbar.n < num_iterations: for batch in dataset.iter_batches(batch_size, shuffle=shuffle): - trajectories = await rollout_manager.sample_trajectories( - environments=batch, max_steps=max_rollout_steps - ) + all_trajectories: list[Trajectory] = [] + + for _ in range(num_rollouts_per_env): + trajectories = await rollout_manager.sample_trajectories( + environments=batch, max_steps=max_rollout_steps + ) + all_trajectories.extend(trajectories) - # Close the environment after we have sampled from it, - # in case it needs to tear down resources. - await _close_envs(batch, rollout_manager.catch_env_failures) + # Close the environment after we have sampled from it, + # in case it needs to tear down resources. + await _close_envs(batch, rollout_manager.catch_env_failures) await asyncio.gather(*[ - callback.after_eval_step(trajectories) for callback in callbacks + callback.after_eval_step(all_trajectories) for callback in callbacks ]) pbar.update() @@ -79,6 +84,7 @@ class EvaluatorConfig(BaseModel): "If 0, will not run the eval loop. " ), ) + num_rollouts_per_env: int = 1 max_rollout_steps: int | None = None catch_agent_failures: bool = True catch_env_failures: bool = True @@ -131,6 +137,7 @@ async def run(self, **kwargs) -> None: "num_iterations": self.config.num_eval_iterations, "max_rollout_steps": self.config.max_rollout_steps, "shuffle": self.config.shuffle, + "num_rollouts_per_env": self.config.num_rollouts_per_env, } eval_kwargs |= kwargs await _run_eval_loop(