Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds multiple env support to the SB3 wrapper #123

Merged
merged 12 commits into from
Jul 13, 2023
Merged

Adds multiple env support to the SB3 wrapper #123

merged 12 commits into from
Jul 13, 2023

Conversation

edbeeching
Copy link
Owner

@edbeeching edbeeching commented Jul 11, 2023

Adds an option to interact with parallel Godot executables using the n_parallel argument. This only works in on an export env.

Example usage with 4 parallel envs:

env = StableBaselinesGodotEnv(env_path=args.env_path, n_parallel=4)

@edbeeching edbeeching marked this pull request as ready for review July 12, 2023 09:35
Copy link
Collaborator

@visuallization visuallization left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uuuuh nice, this looks really good. Also nice small improvements regarding typings, documentation and warnings. I will test this straight away and will get back to you.

@Ivan-267
Copy link
Collaborator

Ivan-267 commented Jul 12, 2023

Looks great and worked properly in my quick test.

Regarding the last commit:

parser.add_argument("--n_parallel", default=1, type=int, help="whether to speed up the physics in the env")

A small note is that the help tip can be updated as well, e.g. "How many instances of the environment executable to launch - requires --env_path to be set if > 1.".

I'm not sure whether this should be added, it's just here as a note, potentially setting a unique seed for each environment could encourage exploration:

from:

self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port+p, **kwargs) for p in range(n_parallel)]

to e.g. (untested):

        self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port+p, seed=p, **kwargs) for p in range(n_parallel)]

@edbeeching
Copy link
Owner Author

edbeeching commented Jul 12, 2023

@Ivan-267 , good point about the seed and help description. Thanks.
I also added a check in the StableBaselinesGodotEnv init method to ensure env_path is not None when n_parallel>1

@edbeeching edbeeching merged commit 85b787a into main Jul 13, 2023
@edbeeching edbeeching deleted the sb3-multi-env branch July 13, 2023 19:24
all_obs.extend(obs)
all_rewards.extend(reward)
all_term.extend(term)
all_trunc.extend(trunc)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the truncated should be stored in info, and info should include terminal obs too, see https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/dummy_vec_env.py#L61-L70

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@araffin Thanks for this info. I have a question.

On the Godot side, envs are automatically reset when an episode terminates / truncates, I am not sure how feasible it is to send the terminal obs in the truncated setting without refactoring all the envs. Does the sb3 PPO implementation use the terminal obs for a value preduction in the truncated setting? If it is not possible to refactor in the short term, would you recommend just setting the terminated flag? (with trunc = False for all steps)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the sb3 PPO implementation use the terminal obs for a value preduction

yes https://github.com/DLR-RM/stable-baselines3/blob/5abd50a853e0f667f48ae10769f478c4972eda35/stable_baselines3/common/on_policy_algorithm.py#L194-L205

If it is not possible to refactor in the short term, would you recommend just setting the terminated flag?

you can use done = truncated or terminated (and not set the terminal obs), although it might impact performance for infinite-horizon tasks.
(as a quick check, you can use this wrapper: https://sb3-contrib.readthedocs.io/en/master/common/wrappers.html#timefeaturewrapper and see if there is any performance diff)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants