diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 47d2623d..7a32e7ed 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install .[test,sb3] + pip install .[test] - name: Download examples run: | make download_examples diff --git a/godot_rl/main.py b/godot_rl/main.py index d3226642..b331bf30 100644 --- a/godot_rl/main.py +++ b/godot_rl/main.py @@ -66,6 +66,7 @@ def get_args(): parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]") parser.add_argument("--experiment_name", default=None, type=str, help="The name of the experiment [only for rllib]") parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") + return parser.parse_known_args() diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index 4ea083d3..7ae1132f 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -69,7 +69,7 @@ def step_wait(self): raise NotImplementedError() -def stable_baselines_training(args, extras): +def stable_baselines_training(args, extras, n_steps=200000): # TODO: Add cla etc for sb3 env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup) @@ -81,7 +81,7 @@ def stable_baselines_training(args, extras): n_steps=32, tensorboard_log="logs/log", ) - model.learn(200000) + model.learn(n_steps) print("closing env") env.close() diff --git a/tests/test_sb3_training.py b/tests/test_sb3_training.py new file mode 100644 index 00000000..4a800824 --- /dev/null +++ b/tests/test_sb3_training.py @@ -0,0 +1,19 @@ +import pytest + +from godot_rl.core.godot_env import GodotEnv +from godot_rl.main import get_args + +try: + from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training +except ImportError as e: + + def stable_baselines_training(args, extras): + print("Import error when trying to use sb3, this is probably not installed try pip install godot-rl[sb3]") + + +def test_sb3_training(): + args, extras = get_args() + args.env = "gdrl" + args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" + + stable_baselines_training(args, extras, n_steps=10000)