Skip to content

Commit

Permalink
updates ci, adds sb3 test
Browse files Browse the repository at this point in the history
  • Loading branch information
edbeeching committed Mar 27, 2023
1 parent 8dc717c commit ee4ad36
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test,sb3]
pip install .[test]
- name: Download examples
run: |
make download_examples
Expand Down
1 change: 1 addition & 0 deletions godot_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_args():
parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]")
parser.add_argument("--experiment_name", default=None, type=str, help="The name of the experiment [only for rllib]")
parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process")

return parser.parse_known_args()


Expand Down
4 changes: 2 additions & 2 deletions godot_rl/wrappers/stable_baselines_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def step_wait(self):
raise NotImplementedError()


def stable_baselines_training(args, extras):
def stable_baselines_training(args, extras, n_steps=200000):
# TODO: Add cla etc for sb3
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup)

Expand All @@ -81,7 +81,7 @@ def stable_baselines_training(args, extras):
n_steps=32,
tensorboard_log="logs/log",
)
model.learn(200000)
model.learn(n_steps)

print("closing env")
env.close()
19 changes: 19 additions & 0 deletions tests/test_sb3_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from godot_rl.core.godot_env import GodotEnv
from godot_rl.main import get_args

try:
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
except ImportError as e:

def stable_baselines_training(args, extras):
print("Import error when trying to use sb3, this is probably not installed try pip install godot-rl[sb3]")


def test_sb3_training():
args, extras = get_args()
args.env = "gdrl"
args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64"

stable_baselines_training(args, extras, n_steps=10000)

0 comments on commit ee4ad36

Please sign in to comment.