Support for MultiBinary / MultiDiscrete spaces #13

Merged (29 commits) on May 18, 2020

Contributor

@rolandgvc rolandgvc commented May 9, 2020

Description

  • Added support for MultiDiscrete and MultiBinary observation / action spaces for PPO and A2C
  • Added MultiCategorical and Bernoulli distributions
  • Added tests for MultiCategorical and Bernoulli distributions and action spaces
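
As a rough illustration of what the two distributions compute, here is a minimal NumPy sketch (not the code from this PR, which is built on PyTorch). It assumes the policy outputs one flat logits vector: for MultiDiscrete the vector is split into one chunk per sub-action and the categorical log-probabilities are summed, and for MultiBinary each logit parameterizes an independent Bernoulli. The function names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over a 1-D array
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))


def multicategorical_log_prob(logits, action_dims, actions):
    """Sum of per-dimension categorical log-probabilities (hypothetical helper)."""
    # Split the flat logits into one chunk per discrete sub-action
    split_points = np.cumsum(action_dims)[:-1]
    chunks = np.split(np.asarray(logits, dtype=float), split_points)
    return float(sum(log_softmax(chunk)[a] for chunk, a in zip(chunks, actions)))


def bernoulli_log_prob(logits, actions):
    """Sum of independent Bernoulli log-probabilities (hypothetical helper)."""
    logits = np.asarray(logits, dtype=float)
    actions = np.asarray(actions)
    # log(sigmoid(x)) = -log(1 + exp(-x)); log(1 - sigmoid(x)) = -log(1 + exp(x))
    log_p1 = -np.logaddexp(0.0, -logits)
    log_p0 = -np.logaddexp(0.0, logits)
    return float(np.sum(np.where(actions == 1, log_p1, log_p0)))


# Uniform logits over sub-actions of size 3 and 2: log(1/3) + log(1/2)
mc_lp = multicategorical_log_prob(np.zeros(5), [3, 2], [2, 0])
# Three fair coin flips: 3 * log(1/2)
bern_lp = bernoulli_log_prob(np.zeros(3), [1, 0, 1])
```

With this layout the action network would need `sum(action_dims)` outputs for MultiDiscrete and one output per bit for MultiBinary.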

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

closes #5
closes #4

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have checked the codestyle using make lint
  • I have ensured pytest and pytype both pass.

@araffin araffin added the "PR template not filled" label May 9, 2020
@araffin
Member

araffin commented May 9, 2020

Even though it is a draft for now, please keep the full PR template ;)

@araffin
Member

araffin commented May 9, 2020

You can push on the same branch afterward to update the PR

@rolandgvc
Contributor Author

@araffin are the namings and overall code design ok?

@araffin araffin self-requested a review May 10, 2020 08:54
@araffin araffin changed the title multicategorical dist and test Multicategorical distribution May 10, 2020
@araffin araffin removed the "PR template not filled" label May 10, 2020
@araffin
Member

araffin commented May 10, 2020

Not sure I will have time to review today... and GitLab CI does not work for forks yet, so I created a branch to check the status: https://github.com/DLR-RM/stable-baselines3/tree/rolandgvc/master

@araffin
Member

araffin commented May 10, 2020

some errors in the pipeline: https://gitlab.com/araffin/stable-baselines3/pipelines/144486506

@rolandgvc
Contributor Author

Fixed the List problem.

Member

@araffin araffin left a comment

Overall, this is good work =)
Some minor things (double checking the shapes returned)

Missing:

  • running tests with the current algorithms (you can find test environments for that in identity_env.py)
  • updating the changelog

PS: I recommend running the tests locally (at least the type check, which is fast to run)

stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/preprocessing.py (outdated)
stable_baselines3/common/preprocessing.py
stable_baselines3/common/preprocessing.py (outdated)
tests/test_distributions.py (outdated)
tests/test_distributions.py (outdated)
@araffin araffin changed the title Multicategorical distribution Additional distributions May 10, 2020
@araffin araffin changed the title Additional distributions Additional action spaces support May 10, 2020
@araffin
Member

araffin commented May 10, 2020

Before I forget, you should also update policies.py for the preprocessing (there is some commented code there)

@rolandgvc
Contributor Author

Thanks for the input 👍 Will have everything done in the next couple of days

@rolandgvc
Contributor Author

Shouldn't dim take a vector with the sizes of each discrete space?

class IdentityEnvMultiDiscrete(IdentityEnv):
    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: (int) the size of the dimensions you want to learn
        :param ep_length: (int) the length of each episode in timesteps
        """
        space = MultiDiscrete([dim, dim])
        super().__init__(ep_length=ep_length, space=space)

Like here:

class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions.

    :param action_dims: ([int]) List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super(MultiCategoricalDistribution, self).__init__()
        self.action_dims = action_dims
        self.distributions = None

@araffin
Member

araffin commented May 12, 2020

Shouldn't dim take a vector with the sizes of each discrete space?

yes, I think this choice was made for simplicity (in the end we use [dim, dim] for the shape) but you can change that ;)

@araffin
Member

araffin commented May 12, 2020

@rolandgvc you should pull, I updated your PR to include automated tests ;)

@rolandgvc
Contributor Author

I'm having the following error when trying to train SAC with the multidiscrete identity env:

observation: tensor([[1, 2]])
one hot: tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.]])
Traceback (most recent call last):
  File "test.py", line 17, in <module>
    model.learn(1000)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/sac/sac.py", line 268, in learn
    log_interval=log_interval)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/common/base_class.py", line 812, in collect_rollouts
    unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/common/base_class.py", line 321, in predict
    return self.policy.predict(observation, state, mask, deterministic)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/common/policies.py", line 232, in predict
    actions = self._predict(observation, deterministic=deterministic)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/sac/policies.py", line 367, in _predict
    return self.actor(observation, deterministic)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/sac/policies.py", line 169, in forward
    mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/sac/policies.py", line 154, in get_action_dist_params
    latent_pi = self.latent_pi(features)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/Applications/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [2 x 4], m2: [2 x 256] at ../aten/src/TH/generic/THTensorMath.cpp:41

Any hints where I could look?
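
One possible reading of the trace above: the one-hot encoding keeps one row per sub-space (shape [2 x 4]), while the first linear layer seems to have been built for 2 input features (m2: [2 x 256]). A common way to feed a MultiDiscrete observation to an MLP is to concatenate the per-dimension one-hot vectors into a single feature vector per sample. The sketch below uses NumPy for illustration; the helper name and `nvec=[4, 4]` are assumptions inferred from the printed one-hot width, not code from this PR.

```python
import numpy as np


def one_hot_multidiscrete(obs, nvec):
    """Encode a (batch, n_dims) integer array as (batch, sum(nvec)) floats."""
    obs = np.asarray(obs)
    # One-hot each column separately, then join along the feature axis
    pieces = [np.eye(n, dtype=np.float32)[obs[:, i]] for i, n in enumerate(nvec)]
    return np.concatenate(pieces, axis=-1)


obs = np.array([[1, 2]])  # same observation as in the printout above
encoded = one_hot_multidiscrete(obs, nvec=[4, 4])  # nvec is an assumption
print(encoded.shape)  # (1, 8)
```

Under this encoding the first layer of the network would need `sum(nvec)` input features rather than `len(nvec)`.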

@rolandgvc
Contributor Author

Or maybe I am misunderstanding how multidiscrete observations work?

@araffin
Member

araffin commented May 12, 2020

I'm having the following error when trying to train SAC with the multidiscrete identity env:

Be careful, I think you need to change the action space when using SAC with this env (because SAC and TD3 only support continuous actions, aka Box space)

@rolandgvc
Contributor Author

And for observation space [3,3], I'm getting:

observation: tensor([[2, 2]])
one hot: tensor([[0., 0., 1.],
        [0., 0., 1.]])
Traceback (most recent call last):
  File "test.py", line 17, in <module>
    model.learn(1000)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/sac/sac.py", line 268, in learn
    log_interval=log_interval)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/common/base_class.py", line 854, in collect_rollouts
    replay_buffer.add(self._last_original_obs, new_obs_, buffer_action, reward_, done)
  File "/Users/rolandgavrilescu/Github/stable-baselines3/stable_baselines3/common/buffers.py", line 181, in add
    self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (3,2) into shape (1,2)
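
The same class of error can be reproduced with plain NumPy. This is only a sketch under assumed sizes (buffer length 10, one environment, action dimension 2); it shows that a buffer slot of shape (1, 2) accepts an action of that shape but rejects a (3, 2) array, which would be consistent with a one-hot encoded array being stored where an action was expected.

```python
import numpy as np

buffer_size, n_envs, action_dim = 10, 1, 2  # hypothetical sizes
actions = np.zeros((buffer_size, n_envs, action_dim))


def try_store(buffer, pos, action):
    """Return True if the action fits the buffer slot, False on a broadcast error."""
    try:
        buffer[pos] = np.array(action).copy()
        return True
    except ValueError:
        return False


ok = try_store(actions, 0, np.zeros((1, 2)))   # matches the slot shape (1, 2)
bad = try_store(actions, 1, np.zeros((3, 2)))  # shape (3, 2) cannot broadcast into (1, 2)
```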

@rolandgvc
Contributor Author

How do I change the action space?

@rolandgvc
Contributor Author

@araffin have a look

@araffin
Member

araffin commented May 14, 2020

I will try to do a full review today, but no promises.
In the meantime, I added a checkbox (about codestyle, see #19) in the PR template. Please fix the issues related to your PR (you can leave the other ones), see "lint with flake8" in https://github.com/DLR-RM/stable-baselines3/pull/13/checks?check_run_id=672562856

Looking at the test logs, there are some worrying warnings in the "test with pytest" step:
https://github.com/DLR-RM/stable-baselines3/pull/13/checks?check_run_id=672562856

tests/test_spaces.py::test_identity_multidiscrete[A2C]
  /home/runner/work/stable-baselines3/stable-baselines3/stable_baselines3/a2c/a2c.py:121: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
    value_loss = F.mse_loss(rollout_data.returns, values)

tests/test_spaces.py::test_identity_multidiscrete[PPO]
  /home/runner/work/stable-baselines3/stable-baselines3/stable_baselines3/ppo/ppo.py:253: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
    value_loss = F.mse_loss(rollout_data.returns, values_pred)
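
To see why these warnings matter, here is a small sketch of the underlying effect. It uses NumPy as a stand-in, since PyTorch broadcasting follows NumPy-style rules; the arrays and the `checked_mse` helper are hypothetical. With shapes (5,) and (1,), as in the A2C warning, the loss is computed without any error, but every return is compared against the same single value.

```python
import numpy as np


def mse(pred, target):
    # Plain mean squared error; mismatched shapes are broadcast silently
    return float(np.mean((pred - target) ** 2))


def checked_mse(pred, target):
    # Defensive variant: refuse to broadcast
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    return mse(pred, target)


returns = np.array([1.0, 2.0, 3.0, 4.0, 5.0])  # shape (5,)
values = np.array([3.0])                        # shape (1,)

silent_loss = mse(returns, values)  # no error: 3.0 is broadcast against all five returns
```

Making both tensors the same shape before computing the loss (for example by flattening them) avoids the silent broadcast.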

@rolandgvc
Contributor Author

yes, I was working on that now

Member

@araffin araffin left a comment

Overall, this is good work =)
Some minor comments/improvements here and there ;)

Once you have addressed them, I will help you write tests for SAC/TD3

stable_baselines3/common/distributions.py
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/distributions.py (outdated)
stable_baselines3/common/preprocessing.py (outdated)
stable_baselines3/ppo/policies.py (outdated)
stable_baselines3/ppo/policies.py (outdated)
tests/test_distributions.py (outdated)
tests/test_spaces.py (outdated)
@rolandgvc
Contributor Author

Ready when you are 👍

@araffin
Member

araffin commented May 17, 2020

I will try to take a look today or tomorrow ;)

@araffin araffin self-requested a review May 18, 2020 09:43
@araffin
Member

araffin commented May 18, 2020

@rolandgvc could you give me access to your fork? I could not push some changes...

See https://github.com/DLR-RM/stable-baselines3/tree/pull_13

@rolandgvc
Contributor Author

You mean create a pull request?
[Screenshot 2020-05-18 at 13 11 36]

@araffin
Member

araffin commented May 18, 2020

You mean create a pull request?

I meant write access to your fork, so I can push my changes. (normally, you should have ticked "allow edits from maintainers" when creating this PR).

@araffin
Member

araffin commented May 18, 2020

[Screenshot from 2020-05-18 14-22-38]

never mind, I created a PR in your fork...

@rolandgvc
Contributor Author

But I did
[Screenshot 2020-05-18 at 13 23 35]

Member

@araffin araffin left a comment

LGTM, thank you very much for the good work =)

@araffin araffin merged commit 91adefd into DLR-RM:master May 18, 2020