Match performance with stable-baselines (discrete case) #110

Merged · 15 commits · Aug 3, 2020
2 changes: 2 additions & 0 deletions stable_baselines3/common/base_class.py
@@ -123,6 +123,7 @@ def __init__(
         self.tensorboard_log = tensorboard_log
         self.lr_schedule = None  # type: Optional[Callable]
         self._last_obs = None  # type: Optional[np.ndarray]
+        self._last_dones = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
         self._last_original_obs = None  # type: Optional[np.ndarray]
         self._episode_num = 0
@@ -474,6 +475,7 @@ def _setup_learn(
         # Avoid resetting the environment when calling ``.learn()`` consecutive times
         if reset_num_timesteps or self._last_obs is None:
             self._last_obs = self.env.reset()
+            self._last_dones = np.zeros((self._last_obs.shape[0],), dtype=np.bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
3 changes: 2 additions & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -173,8 +173,9 @@ def collect_rollouts(
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
-            rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
             self._last_obs = new_obs
+            self._last_dones = dones

rollout_buffer.compute_returns_and_advantage(values, dones=dones)
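The reasoning behind the fix, sketched as a toy loop (hypothetical helper, not the SB3 API): the `dones` returned by the env at a given step describe the transition that *ends* at `new_obs`, while the buffer entry for `self._last_obs` needs the flags from the *previous* step, which say whether `self._last_obs` began a fresh episode.

```python
import numpy as np

def collect(episode_len=3, n_steps=5):
    """Toy 1-env rollout: pair each stored obs with the flags that
    say whether that obs began a new episode (the PR's fix)."""
    last_obs, last_dones = 0, np.array([False])
    buffer = []
    for _ in range(n_steps):
        new_obs = (last_obs + 1) % episode_len
        dones = np.array([new_obs == 0])  # episode ends when we wrap to 0
        # Store the flags from the PREVIOUS step, not the current ones.
        buffer.append((last_obs, bool(last_dones[0])))
        last_obs, last_dones = new_obs, dones
    return buffer

print(collect())
# Only the obs that reappears right after a reset is flagged True.
```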

2 changes: 1 addition & 1 deletion stable_baselines3/common/torch_layers.py
@@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
             nn.ReLU(),
             nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
             nn.ReLU(),
-            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
Member commented:
nice catch 🙈 Please don't tell me that solves your performance issue.

I know where it comes from ... I shouldn't have copy-pasted from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L169

Member commented:

Thinking about that, we need to double-check VecFrameStack, even though it is the same as in SB2.

@Miffyli (Collaborator, Author) commented on Jul 20, 2020:

Sadly (luckily? =)) it did not fix the issues yet. SB3 is still consistently worse in a few of the Atari games I have tested. I am in the process of doing step-by-step comparisons; will see how that goes.

Edit: Ah yes, stacking on the wrong channels?
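A quick way to sanity-check the stacking axis mentioned above, assuming image observations in HWC layout (an illustrative pure-NumPy check, not code from this PR): stacking grayscale frames on the channel (last) axis keeps the spatial shape intact, while stacking on the wrong axis silently changes the layout.

```python
import numpy as np

# Four consecutive 84x84 grayscale frames, each in HWC with one channel
frames = [np.full((84, 84, 1), i, dtype=np.uint8) for i in range(4)]

# Stacking on the channel (last) axis -- the layout the CNN expects
stacked = np.concatenate(frames, axis=-1)
assert stacked.shape == (84, 84, 4)

# Stacking on the wrong axis produces a very different observation
wrong = np.concatenate(frames, axis=0)
assert wrong.shape == (336, 84, 1)
```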

Member commented:

or having that kind of issue: ikostrikov/pytorch-a2c-ppo-acktr-gail@84a7582

btw, is it better now with OMP_NUM_THREADS=1 w.r.t. FPS? (Maybe you should write in a comment the current standing of SB2 vs SB3.)
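For reference, the thread-count tuning discussed here can be applied either via the environment variable (set before torch is imported, so OpenMP picks it up) or programmatically; a minimal sketch of a hypothetical training script, with the torch import guarded in case it is unavailable:

```python
import os

# Must be set BEFORE importing torch so OpenMP reads it at load time.
os.environ["OMP_NUM_THREADS"] = "1"

try:
    import torch
    torch.set_num_threads(1)  # programmatic equivalent for intra-op threads
except ImportError:
    pass  # torch not installed; the env var still applies for OpenMP users
```

On small networks (toy environments), oversubscribed OpenMP threads often cost more in synchronization than they gain in parallelism, which is consistent with the 25% gap mentioned below.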

Member commented:

One thing that may differ is the optimizer implementation and default parameters. For the initialization, I think I reproduced (at least I tried to reproduce) what was done in SB2.

Collaborator Author commented:

> my question was more: what is the FPS we want to reach? (what did you have with SB2?)

Hmm, I do not have conclusive numbers just yet, because I have been running many experiments on the same system and cannot guarantee fair comparisons, but I think the PyTorch variants are about 10% slower on Atari games and 25% slower on toy environments. The latter required the OMP_NUM_THREADS tuning. This sounds reasonable, given the non-compiled nature of PyTorch and the fact that the code has not been optimized much yet.

> Yes, the issue was that nminibatches led to a different mini-batch size depending on the number of environments

Ah, alright. I will write big notes about this in the "moving from stable-baselines" docs :)

@m-rph (Contributor) commented on Jul 25, 2020:

> One major change in parameters is the use of batch_size=64 rather than nminibatches=4 in PPO. Using such a small batch size made things very slow FPS-wise, but in some cases it sped up the learning. I will focus more on these running-speed things in another PR.

I would like to add that we may be able to gain a non-minuscule speedup by avoiding single data stores and instead storing a whole batch at once.
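To make the nminibatches point concrete (illustrative numbers, not values taken from this PR): SB2-style PPO derived the mini-batch size from the rollout size, so it scaled with the number of environments, whereas SB3's batch_size pins it explicitly.

```python
n_steps, nminibatches = 128, 4

# SB2-style PPO: mini-batch size depends on how many envs you run
sb2_sizes = {n_envs: n_envs * n_steps // nminibatches for n_envs in (1, 4, 8)}
print(sb2_sizes)  # {1: 32, 4: 128, 8: 256}

# SB3-style PPO: the mini-batch size is fixed regardless of n_envs
batch_size = 64
```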

Member commented:

I started documenting the migration here ;) #123

> I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a while batch at once.

?

Contributor commented:

Mistyped, I meant that if we store a whole batch at once, we should get a sizeable speedup over storing one transition at a time.
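A pure-NumPy illustration of that point (hypothetical buffer layout, not the SB3 buffer classes): both approaches below store the same data, but the sliced assignment does one contiguous copy instead of 64 small ones plus Python loop overhead.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.standard_normal((64, 8)).astype(np.float32)

buf_loop = np.zeros((1000, 8), dtype=np.float32)
buf_vec = np.zeros((1000, 8), dtype=np.float32)

# One transition at a time: many tiny copies, interpreted loop
for i, transition in enumerate(batch):
    buf_loop[i] = transition

# Whole batch at once: a single vectorized slice assignment
buf_vec[:len(batch)] = batch

assert np.array_equal(buf_loop, buf_vec)
```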

Member commented:

still not sure what you mean...

nn.ReLU(),
nn.Flatten(),
)
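With the filter-count fix applied, the extractor matches the Nature DQN architecture (32, 64, 64 filters). A standalone shape check for a stacked 84x84x4 Atari observation (assumes PyTorch is available; this is a sketch, not the exact SB3 NatureCNN class):

```python
import torch
import torch.nn as nn

cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=0),   # 84 -> 20
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),  # 20 -> 9
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),  # 9 -> 7, 64 filters as in Nature DQN
    nn.ReLU(),
    nn.Flatten(),
)

out = cnn(torch.zeros(1, 4, 84, 84))
print(out.shape)  # torch.Size([1, 3136]); 64 * 7 * 7 features feed the 512-unit linear layer
```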