-
Notifications
You must be signed in to change notification settings - Fork 258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce interactive policies to gather data from a user #776
Changes from 10 commits
b8d1616
09c5f2f
4872ceb
b0efd61
6a9389e
8cb822a
3a1d10f
f63bf5e
79749dd
f6ebbc2
a3864c3
d0e0ecd
10936f7
73a33da
b94faf4
9d93854
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Training DAgger with an interactive policy that queries the user for actions. | ||
|
||
Note that this is a toy example that does not lead to training a reasonable policy. | ||
""" | ||
|
||
import tempfile | ||
michalzajac-ml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import gym | ||
import numpy as np | ||
from stable_baselines3.common import vec_env | ||
|
||
from imitation.algorithms import bc, dagger | ||
from imitation.policies import interactive | ||
|
||
if __name__ == "__main__":
    # A single seeded generator is shared by the BC learner and the DAgger trainer.
    rng = np.random.default_rng(0)

    def _make_pong() -> gym.Env:
        """Creates a Pong environment wrapped in a 10-step `TimeLimit`."""
        return gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)

    venv = vec_env.DummyVecEnv([_make_pong])
    venv.seed(0)

    # The "expert" is an interactive policy: the user is asked for the actions.
    human_expert = interactive.AtariInteractivePolicy(venv)

    learner = bc.BC(
        observation_space=venv.observation_space,
        action_space=venv.action_space,
        rng=rng,
    )

    # DAgger needs a scratch directory; a temporary one is removed on exit.
    with tempfile.TemporaryDirectory(prefix="dagger_example_") as scratch_dir:
        trainer = dagger.SimpleDAggerTrainer(
            venv=venv,
            scratch_dir=scratch_dir,
            expert_policy=human_expert,
            bc_trainer=learner,
            rng=rng,
        )
        trainer.train(
            total_timesteps=20,
            rollout_round_min_episodes=1,
            rollout_round_min_timesteps=10,
        )
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,152 @@ | ||||||
"""Interactive policies that query the user for actions.""" | ||||||
|
||||||
import abc | ||||||
import collections | ||||||
import typing | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Our style guide allows importing types directly from typing (i.e. `from typing import ...`; the original inline example was lost in extraction).
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Thanks, good to know! |
||||||
|
||||||
import gym | ||||||
import matplotlib.pyplot as plt | ||||||
import numpy as np | ||||||
from stable_baselines3.common import vec_env | ||||||
|
||||||
import imitation.policies.base as base_policies | ||||||
from imitation.util import util | ||||||
|
||||||
|
||||||
class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
    """Abstract class for interactive policies with discrete actions.

    For each query, the observation is rendered and then the action is provided
    as a keyboard input.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        action_keys_names: collections.OrderedDict,
        clear_screen_on_query: bool = True,
    ):
        """Builds DiscreteInteractivePolicy.

        Args:
            observation_space: Observation space.
            action_space: Action space. Must be a `gym.spaces.Discrete`.
            action_keys_names: `OrderedDict` containing pairs (key, name) for every
                action, where key will be used in the console interface, and name
                is a semantic action name. The position of a pair in the
                dictionary is the index of the action that its key selects.
            clear_screen_on_query: If `True`, console will be cleared on every query.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
        )

        assert isinstance(
            action_space,
            gym.spaces.Discrete,
        ), "action_space must be Discrete"
        # Exactly one (key, name) pair per action, with no repeated names.
        assert (
            len(action_keys_names)
            == len(set(action_keys_names.values()))
            == action_space.n
        ), "action_keys_names needs one uniquely named entry per action"

        self.action_keys_names = action_keys_names
        # Keys are mapped to action indices by their position in the dictionary.
        self.action_key_to_index = {
            k: i for i, k in enumerate(action_keys_names.keys())
        }
        self.clear_screen_on_query = clear_screen_on_query

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        """Renders the observation and asks the user which action to take.

        Args:
            obs: Observation to show to the user.

        Returns:
            A one-element array holding the index of the selected action.
        """
        if self.clear_screen_on_query:
            util.clear_screen()

        context = self._render(obs)
        try:
            key = self._get_input_key()
        finally:
            # Release whatever `_render` opened even if reading the input fails
            # (e.g. EOFError or KeyboardInterrupt raised by `input`).
            self._clean_up(context)

        return np.array([self.action_key_to_index[key]])

    def _get_input_key(self) -> str:
        """Obtains input key for action selection.

        Keeps prompting until the entered text is one of the configured keys.

        Returns:
            The key entered by the user.
        """
        print(
            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
            ", ".join([f"{n}:{k}" for k, n in self.action_keys_names.items()]),
        )

        key = input("Your choice (enter key):")
        while key not in self.action_keys_names:
            key = input("Invalid key, please try again! Your choice (enter key):")

        return key

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> typing.Optional[object]:
        """Renders an observation, optionally returns a context for later cleanup."""

    def _clean_up(self, context: object) -> None:
        """Cleans up after the input query, e.g. stops showing the image.

        Args:
            context: Whatever `_render` returned for the current query.
        """
        pass
michalzajac-ml marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """Interactive discrete-action policy that displays observations as images."""

    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
        """Turns an observation into the image to display (unchanged by default)."""
        return obs

    def _render(self, obs: np.ndarray) -> plt.Figure:
        """Shows the observation in a new figure and returns that figure."""
        image = self._prepare_obs_image(obs)

        figure, axes = plt.subplots()
        # cmap is ignored for RGB images.
        axes.imshow(image, cmap="gray", vmin=0, vmax=255)
        axes.axis("off")
        figure.show()

        return figure

    def _clean_up(self, context: plt.Figure) -> None:
        """Closes the figure that `_render` returned."""
        plt.close(context)
|
||||||
|
||||||
# Console key used to select each Atari action, indexed by action name.
# The eight plain directions use the letter block q/w/e, a/d, z/x/c as a
# direction pad (e.g. "q" is UPLEFT, "x" is DOWN); the same directions combined
# with FIRE use the block r/t/y, f/h, v/b/n. NOOP and FIRE are "1" and "2".
# NOTE(review): the layout presumably targets a QWERTY keyboard -- confirm.
ATARI_ACTION_NAMES_TO_KEYS = {
    "NOOP": "1",
    "FIRE": "2",
    "UP": "w",
    "RIGHT": "d",
    "LEFT": "a",
    "DOWN": "x",
    "UPRIGHT": "e",
    "UPLEFT": "q",
    "DOWNRIGHT": "c",
    "DOWNLEFT": "z",
    "UPFIRE": "t",
    "RIGHTFIRE": "h",
    "LEFTFIRE": "f",
    "DOWNFIRE": "b",
    "UPRIGHTFIRE": "y",
    "UPLEFTFIRE": "r",
    "DOWNRIGHTFIRE": "n",
    "DOWNLEFTFIRE": "v",
}
|
||||||
|
||||||
class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
    """Interactive policy for Atari environments."""

    def __init__(self, env: typing.Union[gym.Env, vec_env.VecEnv], *args, **kwargs):
        """Builds AtariInteractivePolicy.

        Args:
            env: Environment whose spaces and action meanings configure the policy.
            *args: Extra positional arguments forwarded to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        if isinstance(env, gym.Env):
            meanings = env.get_action_meanings()
        else:
            # Otherwise query only the sub-environment at index 0.
            meanings = env.env_method("get_action_meanings", indices=[0])[0]

        # Insertion order follows the order of the reported action meanings.
        keys_to_names = collections.OrderedDict()
        for meaning in meanings:
            keys_to_names[ATARI_ACTION_NAMES_TO_KEYS[meaning]] = meaning

        super().__init__(
            env.observation_space,
            env.action_space,
            keys_to_names,
            *args,
            **kwargs,
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Tests interactive policies.""" | ||
michalzajac-ml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import collections | ||
from unittest import mock | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
from stable_baselines3.common import vec_env | ||
|
||
from imitation.policies import interactive | ||
|
||
# Environment ids that the parametrized tests below are run against.
ENVS = ["CartPole-v0"]
|
||
|
||
class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
    """DiscreteInteractivePolicy whose rendering step does nothing."""

    def _render(self, obs: np.ndarray) -> None:
        """Skips rendering and provides no cleanup context."""
        return None
|
||
|
||
def _get_interactive_policy(env: vec_env.VecEnv):
    """Builds a non-rendering interactive policy for `env`.

    Action `i` is bound to the key `k{i}` and given the name `n{i}`.
    """
    keys_to_names = collections.OrderedDict(
        (f"k{index}", f"n{index}") for index in range(env.action_space.n)
    )
    return NoRenderingDiscreteInteractivePolicy(
        env.observation_space,
        env.action_space,
        keys_to_names,
    )
|
||
|
||
@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy(env_name: str):
    """Test if correct actions are selected, as specified by input keys."""
    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
    env.seed(0)

    interactive_policy = _get_interactive_policy(env)
    action_keys = list(interactive_policy.action_keys_names.keys())

    obs = env.reset()
    done = np.array([False])

    # Seeded locally so the pattern of injected invalid keys is the same on every
    # run and does not depend on (or disturb) the global NumPy random state.
    rng = np.random.default_rng(0)

    class MockInput:
        """Stand-in for `input` that cycles through the valid action keys."""

        def __init__(self):
            self.index = 0

        def __call__(self, _):
            # Sometimes insert incorrect keys, which should get ignored by the policy.
            if rng.uniform() < 0.5:
                return "invalid"
            key = action_keys[self.index]
            self.index = (self.index + 1) % len(action_keys)
            return key

    with mock.patch("builtins.input", MockInput()):
        # Index of the action that the next valid key is expected to select.
        requested_action = 0
        while not done.all():
            action, _ = interactive_policy.predict(obs)
            assert isinstance(action, np.ndarray)
            assert all(env.action_space.contains(a) for a in action)
            assert action[0] == requested_action

            obs, reward, done, info = env.step(action)
            assert isinstance(obs, np.ndarray)
            assert all(env.observation_space.contains(o) for o in obs)
            assert isinstance(reward, np.ndarray)
            assert isinstance(done, np.ndarray)

            requested_action = (requested_action + 1) % len(action_keys)
|
||
|
||
@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy_input_validity(capsys, env_name: str):
    """Checks the console feedback printed for valid and for invalid keys."""

    def make_env():
        return gym.wrappers.TimeLimit(gym.make(env_name), 10)

    env = vec_env.DummyVecEnv([make_env])
    env.seed(0)

    interactive_policy = _get_interactive_policy(env)
    action_keys = list(interactive_policy.action_keys_names.keys())

    # Case 1: a valid key is entered straight away.
    obs = env.reset()

    def always_valid(prompt):
        # Echo the prompt so that it shows up in the captured stdout.
        print(prompt)
        return action_keys[0]

    with mock.patch("builtins.input", always_valid):
        interactive_policy.predict(obs)
        captured = capsys.readouterr().out
        assert "Your choice" in captured
        assert "Invalid" not in captured

    # Case 2: the first answer is invalid, every later answer is valid.
    obs = env.reset()
    calls = 0

    def invalid_then_valid(prompt):
        nonlocal calls
        print(prompt)
        calls += 1
        if calls == 1:
            return "invalid"
        return action_keys[0]

    with mock.patch("builtins.input", invalid_then_valid):
        interactive_policy.predict(obs)
        captured = capsys.readouterr().out
        assert "Your choice" in captured
        assert "Invalid" in captured
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tests for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If a matplotlib GUI backend isn't installed, it'll fail with a somewhat cryptic error (the error text was not captured here), and indeed no figure displays.
Installing the relevant backend seems out of scope for this project. But we might want to check whether the backend is interactive (I think plt.isinteractive() checks for this) and, if not, warn with a link to the relevant docs, e.g. https://matplotlib.org/stable/users/explain/backends.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please check again what message you are getting and whether it is an error or a warning? In my case, when I set a non-GUI backend like `Agg`, I get a warning (the warning text was not captured here) and the execution continues.
Actually, `plt.isinteractive()` checks for the interactive mode; GUI backends like MacOSX can be in both modes (and we do not need to turn on interactive mode for it to work). What we would actually like is to check whether the backend is "GUI" or "non-GUI", but from a simple search there does not seem to be a nice way to do that (other than checking against some whitelist of backends). Given that, and the fact that the warning mentioned above is not that bad, I'd keep this as-is for now. Alternatively, we could opt for throwing an error/assert instead of a warning, but again this would require a whitelist of backends. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For additional context: on my laptop, the example runs nicely although the mode is not interactive by default.