Introduce interactive policies to gather data from a user #776
assert isinstance(action_space, gym.spaces.Discrete)
assert len(action_names) == len(action_keys) == action_space.n
# Names and keys should be unique.
assert len(set(action_names)) == len(set(action_keys)) == action_space.n
Why do we have these both as sequences rather than a dictionary mapping from action key to action name (or vice-versa)? This would enforce the length the same, and uniqueness on the keys, so one would only need to check that len(the_dict) == action_space.n and that the values (action names) are unique.
I guess we would need an ordered dictionary, so perhaps that's a reason against, although plain dictionaries have preserved insertion order in CPython since 3.6 (and as a language guarantee since Python 3.7).
Yeah, I did this because of ordering, and had doubts about enforcing the OrderedDict type, but now I agree it's more elegant, so I'm changing it to an OrderedDict.
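For reference, the dictionary-based check discussed above could look roughly like the sketch below. The function name `validate_action_keys_names` and the example bindings are hypothetical, and a plain integer stands in for `action_space.n` so the snippet runs without gym.

```python
import collections


def validate_action_keys_names(
    action_keys_names: "collections.OrderedDict[str, str]",
    n_actions: int,
) -> None:
    """Check a key -> name mapping against the size of a discrete action space."""
    # Dict keys are unique by construction, so the length check covers the keys.
    if len(action_keys_names) != n_actions:
        raise ValueError(
            f"Expected {n_actions} actions, got {len(action_keys_names)}."
        )
    # Values are not deduplicated by the dict, so check the names explicitly.
    names = list(action_keys_names.values())
    if len(set(names)) != len(names):
        raise ValueError("Action names must be unique.")


# Hypothetical key bindings for a three-action environment.
bindings = collections.OrderedDict([("a", "LEFT"), ("s", "NOOP"), ("d", "RIGHT")])
validate_action_keys_names(bindings, n_actions=3)
```

Compared with two parallel sequences, only the name-uniqueness check remains explicit; the equal-length and unique-key conditions follow from the data structure.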
env.seed(0)

action_names = env.envs[0].get_action_meanings()
names_to_keys = {
There's only a small finite number of legal actions in Atari, so we could define a more comprehensive version of these in a constant somewhere (or even subclass ImageObsDiscreteInteractivePolicy to handle this directly) rather than it having to live in an example.
Done, thanks for the suggestion!
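A rough sketch of what such a constant might look like is given here. The name `ATARI_ACTION_NAMES_TO_KEYS`, the helper `keys_to_names_for`, and every key binding are assumptions made for illustration, not the code merged in this PR; the action names follow the usual 18-action Atari set.

```python
import collections
from typing import Sequence

# Hypothetical key bindings: the keys chosen here are illustrative assumptions.
ATARI_ACTION_NAMES_TO_KEYS = {
    "NOOP": "1",
    "FIRE": "2",
    "UP": "w",
    "RIGHT": "d",
    "LEFT": "a",
    "DOWN": "x",
    "UPRIGHT": "e",
    "UPLEFT": "q",
    "DOWNRIGHT": "c",
    "DOWNLEFT": "z",
    "UPFIRE": "t",
    "RIGHTFIRE": "h",
    "LEFTFIRE": "f",
    "DOWNFIRE": "b",
    "UPRIGHTFIRE": "y",
    "UPLEFTFIRE": "r",
    "DOWNRIGHTFIRE": "n",
    "DOWNLEFTFIRE": "v",
}


def keys_to_names_for(
    action_names: Sequence[str],
) -> "collections.OrderedDict[str, str]":
    """Build an ordered key -> name mapping in the environment's action order."""
    return collections.OrderedDict(
        (ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names
    )


# Stand-in for the list returned by env.envs[0].get_action_meanings().
pong_like_actions = ["NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"]
mapping = keys_to_names_for(pong_like_actions)
```

Because the mapping is built in the order of the environment's action meanings, the position of each key matches the index of the corresponding discrete action.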
Co-authored-by: Jason Hoelscher-Obermaier <jas-ho@users.noreply.github.com>
@@ -30,9 +30,12 @@ def _paths_to_strs(x: Iterable[pathlib.Path]) -> Sequence[str]:
EXAMPLES_DIR = THIS_DIR / ".." / "examples"
TUTORIALS_DIR = THIS_DIR / ".." / "docs" / "tutorials"

SH_PATHS = _paths_to_strs(EXAMPLES_DIR.glob("*.sh"))
EXCLUDED_EXAMPLE_FILES = ["train_dagger_atari_interactive_policy.py"]
One could consider an alternative where we mock parts of the example script. However, that does not seem very useful, since we already have unit tests that check the same things the mocked version would check.
Thanks for this PR! Looks nearly ready -- the only major suggestion is to add some more tests (covering AtariInteractivePolicy); the others are pretty minor.
@@ -0,0 +1,41 @@
"""Training DAgger with an interactive policy that queries the user for actions.
If a matplotlib GUI backend isn't installed, it'll fail with a somewhat cryptic error:
fig.show()
and indeed no figure displays.
Installing the relevant backend seems out-of-scope for this project. But might want to check if the backend is interactive (I think plt.isinteractive() checks for this) and warn if not with link to relevant docs e.g. https://matplotlib.org/stable/users/explain/backends.html
Could you please check again what message you are getting and if this is an error or a warning? In my case, when I set a non-GUI backend like Agg, I get a warning like this (and the execution continues):
UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
fig.show()
Actually, plt.isinteractive() checks for the interactive mode; GUI backends like MacOSX can be in both modes (and we do not need to turn on interactive mode for it to work). What we would actually like is to check whether the backend is "GUI" or "non-GUI", but from a simple search, there does not seem to be a nice way to do that (other than checking against some whitelist of backends). Given that, and the fact that the message I listed above is not that bad, I'd keep this as-is for now. Alternatively, we could opt for throwing an error/assert instead of a warning, but again this would require a whitelist of backends. WDYT?
For additional context: on my laptop, the example runs nicely although the mode is not interactive by default.
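To make the whitelist option concrete, a minimal sketch is given below. The set `NON_GUI_BACKENDS` is an assumed list of matplotlib's built-in non-GUI backend names, and both helper functions are hypothetical; the backend name is passed in as a string so the snippet runs without matplotlib installed.

```python
import warnings

# Assumed whitelist of matplotlib's built-in non-GUI backends (lowercase names).
NON_GUI_BACKENDS = frozenset(
    {"agg", "cairo", "pdf", "pgf", "ps", "svg", "template"}
)

BACKENDS_DOC_URL = "https://matplotlib.org/stable/users/explain/backends.html"


def is_non_gui_backend(backend_name: str) -> bool:
    """Return True if the backend name is on the non-GUI whitelist."""
    return backend_name.lower() in NON_GUI_BACKENDS


def warn_if_non_gui_backend(backend_name: str) -> bool:
    """Warn when figures cannot be shown; return True if a warning was issued."""
    if is_non_gui_backend(backend_name):
        warnings.warn(
            f"Matplotlib backend {backend_name!r} is non-GUI, so the "
            f"observation cannot be displayed. See {BACKENDS_DOC_URL}"
        )
        return True
    return False
```

In the example script, the argument would presumably come from `matplotlib.get_backend()`. Third-party or `module://` backends are not on the whitelist and are treated as GUI backends here, which is one reason the approach is fragile.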
self.action_key_to_index = {k: i for i, k in enumerate(action_keys)}
self.action_keys_names = action_keys_names
self.action_key_to_index = {
k: i for i, k in enumerate(action_keys_names.keys())
Iterating over a dict gives you the keys by default (you can leave as-is if you want to be explicit about it)
k: i for i, k in enumerate(action_keys_names.keys())
k: i for i, k in enumerate(action_keys_names)
Yeah, in this case I slightly prefer to keep it.
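A quick check of the equivalence mentioned above, using hypothetical key bindings:

```python
import collections

# Hypothetical bindings, used only to compare the two spellings.
action_keys_names = collections.OrderedDict([("a", "LEFT"), ("d", "RIGHT")])

# Explicit .keys() and plain iteration visit the same keys in the same order.
explicit = {k: i for i, k in enumerate(action_keys_names.keys())}
implicit = {k: i for i, k in enumerate(action_keys_names)}
```

Both comprehensions produce the same key-to-index mapping, so the choice is purely stylistic.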
import abc
from typing import Optional, List
import collections
import typing
Our style guide allows importing types directly from typing (i.e. from typing import Optional is permissible) although it's not obligatory -- fine to use this style if you prefer. https://google.github.io/styleguide/pyguide.html#2241-exemptions
Thanks, good to know!
with mock.patch("builtins.input", mock_input_invalid_then_valid()):
    interactive_policy.predict(obs)
stdout = capsys.readouterr().out
assert "Your choice" in stdout and "Invalid" in stdout
Tests for DiscreteInteractivePolicy look great. We're not testing AtariInteractivePolicy at all, though. It's pretty simple, granted, but it might still be worth testing, even if just with a simple smoke test (it runs, if we feed in a key corresponding to "FIRE" we get the correct action back, etc).
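The shape of such a smoke test might be roughly as follows. `ToyInteractivePolicy` is a hypothetical stand-in written for this sketch, not the library's AtariInteractivePolicy, and the key bindings are assumptions; only the `mock.patch("builtins.input", ...)` pattern mirrors the existing tests.

```python
from unittest import mock


class ToyInteractivePolicy:
    """Hypothetical stand-in: asks for a key until a valid one is entered."""

    def __init__(self, action_keys_names):
        self.action_keys_names = action_keys_names
        self.action_key_to_index = {
            k: i for i, k in enumerate(action_keys_names)
        }

    def choose_action(self) -> int:
        while True:
            key = input("Your choice: ")
            if key in self.action_key_to_index:
                return self.action_key_to_index[key]
            print("Invalid key, try again.")


def smoke_test_fire_key() -> int:
    """Feed an invalid key, then the key bound to FIRE, and return the action."""
    policy = ToyInteractivePolicy({"1": "NOOP", "2": "FIRE"})
    # side_effect makes the patched input() return "x" first, then "2".
    with mock.patch("builtins.input", side_effect=["x", "2"]):
        return policy.choose_action()


fire_index = smoke_test_fire_key()
```

A real test would construct the actual policy with an Atari environment and call its predict method on an observation instead of `choose_action`.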
Co-authored-by: Adam Gleave <adam@gleave.me>
LGTM
…tibleAI#776)
* Pin huggingface_sb3 version.
* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.
* Make random_mdp test deterministic by seeding the environment.
* WIP: Introduce interactive policies to gather data from a user
* Addressing remarks from review
* fixes
* fix types
* formatting
* Dummy commit to acknowledge co-authorship. Co-authored-by: Jason Hoelscher-Obermaier <jas-ho@users.noreply.github.com>
* Exclude interactive example from running during tests
* formatting
* Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me>
* Adressing further suggestions from review
* formatting
* formatting
---------
Co-authored-by: Maximilian Ernestus <maximilian@ernestus.de>
Co-authored-by: Jason Hoelscher-Obermaier <jas-ho@users.noreply.github.com>
Co-authored-by: Adam Gleave <adam@gleave.me>
Description
This PR introduces interactive policies that query the user for actions, as requested in #701.
Such policies can be used e.g. in Behavioral cloning or DAgger.
An example showing the use for Atari is included.
Acknowledgement: tests were heavily based on the ones from #768 by @jas-ho.
Testing
pytest tests/policies/test_interactive.py to run unit tests.
python examples/train_dagger_atari_interactive_policy.py to run the interactive demo.