From 64186f5bf4a2b5390144f4431aedc5e21ca6bce1 Mon Sep 17 00:00:00 2001 From: Vlad Firoiu Date: Fri, 25 Oct 2024 17:51:53 -0400 Subject: [PATCH] train two jit --- slippi_ai/rl/train_two_lib.py | 5 +++-- tests/train_two.sh | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/slippi_ai/rl/train_two_lib.py b/slippi_ai/rl/train_two_lib.py index 7bad766..446601b 100644 --- a/slippi_ai/rl/train_two_lib.py +++ b/slippi_ai/rl/train_two_lib.py @@ -42,7 +42,6 @@ class RuntimeConfig: max_runtime: tp.Optional[int] = None # maximum runtime in seconds log_interval: int = 10 # seconds between logging save_interval: int = 300 # seconds between saving to disk - use_fake_data: bool = False # Periodically reset the environments to deal with memory leaks in dolphin. reset_every_n_steps: tp.Optional[int] = None @@ -57,6 +56,7 @@ class AgentConfig: name: str = nametags.DEFAULT_NAME compile: bool = True + jit_compile: bool = True batch_steps: int = 0 async_inference: bool = False @@ -225,6 +225,7 @@ def agent_kwargs(self) -> dict: state=self.get_state(), name=self.agent_config.name, compile=self.agent_config.compile, + jit_compile=self.agent_config.jit_compile, batch_steps=self.agent_config.batch_steps, async_inference=self.agent_config.async_inference, ) @@ -365,7 +366,7 @@ def run(config: Config): num_envs=config.actor.num_envs, async_envs=config.actor.async_envs, use_gpu=config.actor.gpu_inference, - use_fake_envs=config.runtime.use_fake_data, + use_fake_envs=config.actor.use_fake_envs, # Rewards are overridden in the learner. ) diff --git a/tests/train_two.sh b/tests/train_two.sh index bfbcb68..b898cec 100755 --- a/tests/train_two.sh +++ b/tests/train_two.sh @@ -2,9 +2,9 @@ python slippi_ai/rl/train_two.py \ --config.runtime.max_step=10 \ --config.runtime.log_interval=0 \ --config.learner.learning_rate=0 \ - --config.runtime.use_fake_data=True \ --config.p1.teacher=slippi_ai/data/checkpoints/demo \ --config.p2.teacher=slippi_ai/data/checkpoints/demo \ + --config.actor.use_fake_envs=True \ --config.actor.num_envs=1 \ --config.actor.rollout_length=64 \ --config.runtime.burnin_steps_after_reset=1 \