diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ee78c68835f..c79e4f42c49 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer. OnlineDTActor RSSMPosterior RSSMPrior + set_recurrent_mode + recurrent_mode Multi-agent-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index feef2595422..e52584c4ac4 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.single_action_spec.space.low, - "high": env.single_action_spec.space.high, + "low": env.action_spec_unbatched.space.low, + "high": env.action_spec_unbatched.space.high, }, return_log_prob=True, ) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index affe591be77..a0cea48b510 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.single_action_spec.space.low.to(device), - "high": proof_environment.single_action_spec.space.high.to(device), + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), } # Define input keys diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 3d5fdf6423e..645bc806265 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.single_action_spec.space.low.to(device), - "high": proof_environment.single_action_spec.space.high.to(device), + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, "safe_tanh": True, } diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index f5c3a9ea3fa..51134b6828d 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg): def make_cql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model - action_spec = train_env.single_action_spec + action_spec = train_env.action_spec_unbatched actor_net, q_net = make_cql_modules_state(model_cfg, eval_env) in_keys = ["observation"] diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 151f8d40408..483bf257c63 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -147,7 +147,7 @@ def make_crossQ_agent(cfg, train_env, device): """Make CrossQ agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.single_action_spec + action_spec = train_env.action_spec_unbatched actor_net_kwargs = { "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], diff 
--git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 79d78db89da..7f905c72366 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -393,7 +393,7 @@ def make_dt_model(cfg): make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) - action_spec = proof_environment.single_action_spec + action_spec = proof_environment.action_spec_unbatched for key, value in proof_environment.observation_spec.items(): if key == "observation": state_dim = value.shape[-1] diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index fba4da253a5..63310113e98 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.single_action_spec.space.low, - "high": proof_environment.single_action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, "tanh_loc": False, } diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 8deb66406f0..ff84d0d8138 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -195,7 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model in_keys = ["observation"] - action_spec = train_env.single_action_spec + action_spec = train_env.action_spec_unbatched actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env) out_keys = ["loc", "scale"] diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index ec55a3aaf33..0a5372a62b0 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ("agents", "action_value"), ("agents", "chosen_action_value"), ], - spec=env.single_action_spec, + spec=env.action_spec_unbatched, action_space=None, ) qnet = SafeSequential(module, value_module) @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, ), ) diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index f6a513717d0..d666d0d6982 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -91,13 +91,13 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "param")], out_keys=[env.action_key], distribution_class=TanhDelta, distribution_kwargs={ - "low": env.single_action_spec[("agents", "action")].space.low, - "high": env.single_action_spec[("agents", "action")].space.high, + "low": env.action_spec_unbatched[("agents", "action")].space.low, + "high": env.action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=False, ) @@ -105,7 +105,7 @@ def train(cfg: "DictConfig"): # noqa: F821 policy_explore = TensorDictSequential( policy, AdditiveGaussianModule( - spec=env.single_action_spec, + spec=env.action_spec_unbatched, 
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, device=cfg.train.device, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index 16acfb34e13..6f50c79b149 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.single_action_spec[("agents", "action")].space.low, - "high": env.single_action_spec[("agents", "action")].space.high, + "low": env.action_spec_unbatched[("agents", "action")].space.low, + "high": env.action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=True, ) diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index 7026091f1b9..a1a3eb35618 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ("agents", "action_value"), ("agents", "chosen_action_value"), ], - spec=env.single_action_spec, + spec=env.action_spec_unbatched, action_space=None, ) qnet = SafeSequential(module, value_module) @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, ), ) diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 26004ecbf3b..f0b11fe6b9c 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821 policy = ProbabilisticActor( module=policy_module, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.single_action_spec[("agents", "action")].space.low, - "high": env.single_action_spec[("agents", "action")].space.high, + "low": env.action_spec_unbatched[("agents", "action")].space.low, + "high": env.action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=True, ) @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "logits")], out_keys=[env.action_key], distribution_class=OneHotCategorical @@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821 actor_network=policy, qvalue_network=value_module, delay_qvalue=True, - action_spec=env.single_action_spec, + action_spec=env.action_spec_unbatched, ) loss_module.set_keys( state_action_value=("agents", "state_action_value"), @@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821 qvalue_network=value_module, delay_qvalue=True, num_actions=env.action_spec.space.n, - action_space=env.single_action_spec, + action_space=env.action_spec_unbatched, ) loss_module.set_keys( action_value=("agents", "action_value"), diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index ee23e876e88..debc8f9e211 100644 --- 
a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.single_action_spec.space.low, - "high": proof_environment.single_action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, } # Define input keys diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 407cf8777c2..6c7a1b80fd7 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.single_action_spec.space.low, - "high": proof_environment.single_action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, "tanh_loc": False, } diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 03a8b57fa81..9953fcb3112 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -410,7 +410,7 @@ def make_redq_model( default_policy_scale = cfg.network.default_policy_scale gSDE = cfg.exploration.gSDE - action_spec = proof_environment.single_action_spec + action_spec = proof_environment.action_spec_unbatched if actor_net_kwargs is None: actor_net_kwargs = {} diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index e8770ee4685..d1dbb2db791 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -161,7 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device): """Make SAC agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.single_action_spec + action_spec = train_env.action_spec_unbatched actor_net_kwargs = { "num_cells": cfg.network.hidden_sizes, "out_features": 2 * action_spec.shape[-1], diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 718c0c3b87e..6cc10123edd 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1388,17 +1388,17 @@ def _make_specs(self): obs_spec_unlazy = consolidate_spec(obs_specs) action_specs = torch.stack(action_specs, dim=0) - self.unbatched_observation_spec = Composite( + self.observation_spec_unbatched = Composite( lazy=obs_spec_unlazy, state=Unbounded(shape=(64, 64, 3)), device=self.device, ) - self.single_action_spec = Composite( + self.action_spec_unbatched = Composite( lazy=action_specs, device=self.device, ) - self.unbatched_reward_spec = Composite( + self.reward_spec_unbatched = Composite( { "lazy": Composite( {"reward": Unbounded(shape=(self.n_nested_dim, 1))}, @@ -1407,7 +1407,7 @@ def _make_specs(self): }, device=self.device, ) - self.unbatched_done_spec = Composite( + self.done_spec_unbatched = Composite( { "lazy": Composite( { @@ -1423,19 +1423,6 @@ def _make_specs(self): device=self.device, ) - self.action_spec = self.single_action_spec.expand( - *self.batch_size, *self.single_action_spec.shape - ) - self.observation_spec = self.unbatched_observation_spec.expand( - *self.batch_size, *self.unbatched_observation_spec.shape - ) - self.reward_spec = self.unbatched_reward_spec.expand( - *self.batch_size, 
*self.unbatched_reward_spec.shape - ) - self.done_spec = self.unbatched_done_spec.expand( - *self.batch_size, *self.unbatched_done_spec.shape - ) - def get_agent_obs_spec(self, i): camera = Bounded(low=0, high=200, shape=(7, 7, 3)) vector_3d = Unbounded(shape=(3,)) @@ -1610,21 +1597,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): self.make_specs() - self.action_spec = self.single_action_spec.expand( - *self.batch_size, *self.single_action_spec.shape - ) - self.observation_spec = self.unbatched_observation_spec.expand( - *self.batch_size, *self.unbatched_observation_spec.shape - ) - self.reward_spec = self.unbatched_reward_spec.expand( - *self.batch_size, *self.unbatched_reward_spec.shape - ) - self.done_spec = self.unbatched_done_spec.expand( - *self.batch_size, *self.unbatched_done_spec.shape - ) - def make_specs(self): - self.unbatched_observation_spec = Composite( + self.observation_spec_unbatched = Composite( nested_1=Composite( observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)), shape=(self.nested_dim_1,), @@ -1642,7 +1616,7 @@ def make_specs(self): ), ) - self.single_action_spec = Composite( + self.action_spec_unbatched = Composite( nested_1=Composite( action=Categorical(n=2, shape=(self.nested_dim_1,)), shape=(self.nested_dim_1,), @@ -1654,7 +1628,7 @@ def make_specs(self): action=OneHot(n=2), ) - self.unbatched_reward_spec = Composite( + self.reward_spec_unbatched = Composite( nested_1=Composite( gift=Unbounded(shape=(self.nested_dim_1, 1)), shape=(self.nested_dim_1,), @@ -1666,7 +1640,7 @@ def make_specs(self): reward=Unbounded(shape=(1,)), ) - self.unbatched_done_spec = Composite( + self.done_spec_unbatched = Composite( nested_1=Composite( done=Categorical( n=2, diff --git a/test/test_cost.py b/test/test_cost.py index 1e157fd7a2f..598b9ba004d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,6 +47,7 @@ DistributionalQValueActor, OneHotCategorical, QValueActor, + recurrent_mode, SafeSequential, WorldModelWrapper, ) @@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + def test_decorators(self): + class MyLoss(LossModule): + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + loss = MyLoss() + loss.forward(None) + loss.actor_loss(None) + loss.something_loss(None) + assert not recurrent_mode() + @pytest.mark.parametrize("expand_dim", [None, 2]) @pytest.mark.parametrize("compare_against", [True, False]) @pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion") diff --git a/test/test_env.py b/test/test_env.py index 05d8308494a..ab854a3b4be 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3512,18 +3512,18 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi def test_single_env_spec(): env = NestedCountingEnv(batch_size=[3, 1, 7]) - assert not env.single_full_action_spec.shape - assert not env.single_full_done_spec.shape - assert not env.single_input_spec.shape - assert not env.single_full_observation_spec.shape - assert not 
env.single_output_spec.shape - assert not env.single_full_reward_spec.shape - - assert env.single_action_spec.shape - assert env.single_reward_spec.shape - - assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape)) - assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape)) + assert not env.full_action_spec_unbatched.shape + assert not env.full_done_spec_unbatched.shape + assert not env.input_spec_unbatched.shape + assert not env.full_observation_spec_unbatched.shape + assert not env.output_spec_unbatched.shape + assert not env.full_reward_spec_unbatched.shape + + assert env.action_spec_unbatched.shape + assert env.reward_spec_unbatched.shape + + assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape)) + assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) if __name__ == "__main__": diff --git a/test/test_libs.py b/test/test_libs.py index 8284133e2da..b3ba8d54c3d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2253,7 +2253,9 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents): max_steps=n_rollout_samples, return_contiguous=False if env.het_specs else True, ) - assert env.single_full_action_spec.shape == env.unbatched_action_spec.shape, ( + assert ( + env.full_action_spec_unbatched.shape == env.unbatched_action_spec.shape + ), ( env.action_spec, env.batch_size, ) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ec9322500b4..d3b7b7850f4 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -36,6 +36,7 @@ OnlineDTActor, ProbabilisticActor, SafeModule, + set_recurrent_mode, TanhDelta, TanhNormal, ValueOperator, @@ -729,6 +730,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): lstm_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + lstm_module = LSTMModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden0", "hidden1"], + out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + default_recurrent_mode=default_val, + ) + assert lstm_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert lstm_module.recurrent_mode + with set_recurrent_mode(False): + assert not lstm_module.recurrent_mode + with set_recurrent_mode("recurrent"): + assert lstm_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert lstm_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): lstm_module = LSTMModule( input_size=3, @@ -754,7 +780,8 @@ def test_python_cudnn(self): num_layers=2, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) obs = torch.rand(10, 20, 3) hidden0 = torch.rand(10, 20, 2, 12) @@ -1109,6 +1136,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): gru_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + default_recurrent_mode=default_val, + ) + assert 
gru_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert gru_module.recurrent_mode + with set_recurrent_mode(False): + assert not gru_module.recurrent_mode + with set_recurrent_mode("recurrent"): + assert gru_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert gru_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): gru_module = GRUModule( input_size=3, diff --git a/test/test_transforms.py b/test/test_transforms.py index 56a39218f5f..8b2ada8c93a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs", "is_init"], out_keys=["output", ("next", "rhs")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): return LSTMModule( @@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs_h", "rhs_c", "is_init"], out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_batch(self, batch_size: int = 2, sequence_length: int = 5): observation = torch.randn(batch_size, sequence_length + 1, 4) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 0b4dd03a636..d37aebb862f 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -15,9 +15,11 @@ import os import pickle import sys +import threading import time import traceback import warnings +from contextlib import nullcontext from copy import copy from distutils.util import strtobool from functools import wraps @@ -32,6 +34,11 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling + LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) @@ -827,3 +834,19 @@ def _make_ordinal_device(device: torch.device): if device.type == "mps" and device.index is None: return torch.device("mps", index=0) return device + + +class _ContextManager: + def __init__(self): + self._mode: Any | None = None + self._lock = threading.Lock() + + def get_mode(self) -> Any | None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + return self._mode + + def set_mode(self, type: Any | None) -> None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + self._mode = type diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index cf784f5659d..ed91b26d8da 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1493,65 +1493,125 @@ def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec: return spec[idx] @property - def single_full_action_spec(self) -> Composite: + def full_action_spec_unbatched(self) -> Composite: """Returns the action spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_action_spec_unbatched.setter + def full_action_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_action_spec = spec + @property - 
def single_action_spec(self) -> TensorSpec: + def action_spec_unbatched(self) -> TensorSpec: """Returns the action spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.action_spec) + @action_spec_unbatched.setter + def action_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.action_spec = spec + @property - def single_full_observation_spec(self) -> Composite: + def full_observation_spec_unbatched(self) -> Composite: """Returns the observation spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_observation_spec_unbatched.setter + def full_observation_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_observation_spec = spec + @property - def single_observation_spec(self) -> Composite: + def observation_spec_unbatched(self) -> Composite: """Returns the observation spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.observation_spec) + @observation_spec_unbatched.setter + def observation_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.observation_spec = spec + @property - def single_full_reward_spec(self) -> Composite: + def full_reward_spec_unbatched(self) -> Composite: """Returns the reward spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_reward_spec_unbatched.setter + def full_reward_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_reward_spec = spec + @property - def single_reward_spec(self) -> TensorSpec: + def reward_spec_unbatched(self) -> TensorSpec: """Returns the reward spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.reward_spec) + @reward_spec_unbatched.setter + def reward_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.reward_spec = spec + @property - def single_full_done_spec(self) -> Composite: + def full_done_spec_unbatched(self) -> Composite: """Returns the done spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_done_spec_unbatched.setter + def full_done_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_done_spec = spec + @property - def single_done_spec(self) -> TensorSpec: + def done_spec_unbatched(self) -> TensorSpec: """Returns the done spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.done_spec) + @done_spec_unbatched.setter + def done_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.done_spec = spec + @property - def single_output_spec(self) -> Composite: + def output_spec_unbatched(self) -> Composite: """Returns the output spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.output_spec) + @output_spec_unbatched.setter + def output_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.output_spec = spec + @property - def single_input_spec(self) -> Composite: + def input_spec_unbatched(self) -> Composite: """Returns the input spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.input_spec) + @input_spec_unbatched.setter + def 
input_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.input_spec = spec + @property - def single_full_state_spec(self) -> Composite: + def full_state_spec_unbatched(self) -> Composite: """Returns the state spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_state_spec) + @full_state_spec_unbatched.setter + def full_state_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_state_spec = spec + @property - def single_state_spec(self) -> TensorSpec: + def state_spec_unbatched(self) -> TensorSpec: """Returns the state spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.state_spec) + @state_spec_unbatched.setter + def state_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.state_spec = spec + def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e02c88c5330..7bdd25591cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7411,7 +7411,8 @@ class BurnInTransform(Transform): ... hidden_size=10, ... in_keys=["observation", "hidden"], ... out_keys=["intermediate", ("next", "hidden")], - ... ).set_recurrent_mode(True) + ... default_recurrent_mode=True, + ... ) >>> burn_in_transform = BurnInTransform( ... modules=[gru_module], ... burn_in=5, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4cb6366f817..edf90a4e85b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -80,10 +80,12 @@ QValueActor, QValueHook, QValueModule, + recurrent_mode, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, SafeSequential, + set_recurrent_mode, TanhModule, ValueOperator, VmapModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 202f84fd173..3fb1559833a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -34,6 +34,15 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule +from .rnn import ( + GRU, + GRUCell, + GRUModule, + LSTM, + LSTMCell, + LSTMModule, + recurrent_mode, + set_recurrent_mode, +) from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 6a99e85812b..f4ceb648665 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional, Tuple +import typing +import warnings +from typing import Any, Optional, Tuple import torch import torch.nn.functional as F @@ -18,6 +20,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase +from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( _inv_pad_sequence, @@ -376,6 +379,9 @@ class LSTMModule(ModuleBase): device (torch.device or compatible): the device of the module. lstm (torch.nn.LSTM, optional): an LSTM instance to be wrapped. 
Exclusive with other nn.LSTM arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode to use when it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. @@ -451,6 +457,7 @@ def __init__( out_keys=None, device=None, lstm=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if lstm is not None: @@ -524,7 +531,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> LSTMModule: """Transforms the LSTM layer in its python-based version. @@ -647,12 +654,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." ) @property @@ -662,7 +672,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -692,7 +702,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The lstm.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurrent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = LSTMModule(lstm=self.lstm, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1155,6 +1171,9 @@ class GRUModule(ModuleBase): device (torch.device or compatible): the device of the module. gru (torch.nn.GRU, optional): a GRU instance to be wrapped. Exclusive with other nn.GRU arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode to use when it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. @@ -1256,6 +1275,7 @@ def __init__( out_keys=None, device=None, gru=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if gru is not None: @@ -1326,7 +1346,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> GRUModule: """Transforms the GRU layer in its python-based version.
@@ -1444,12 +1464,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." ) @property @@ -1459,7 +1482,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -1488,7 +1511,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The gru.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurrent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1598,3 +1627,57 @@ def _gru( ) out = [y, hidden] return tuple(out) + + +# Recurrent mode manager +recurrent_mode_state_manager = _ContextManager() + + +def recurrent_mode() -> bool | None: + """Returns the currently active recurrent mode, or ``None`` if it is unset.""" + return recurrent_mode_state_manager.get_mode() + + +class set_recurrent_mode(_DecoratorContextManager): + """Context manager for setting the recurrent mode of RNN modules. + + Args: + mode (bool, "recurrent" or "sequential"): the recurrent mode to be used within the context manager. + `"recurrent"` leads to `mode=True` and `"sequential"` leads to `mode=False`. + An RNN executed with recurrent_mode "on" assumes that the data comes in time batches, otherwise + it is assumed that each data element in a tensordict is independent of the others. + The default value of this argument is ``True``. + When no context manager is active, the recurrent mode is ``None``, i.e., the default recurrent mode of the RNN is used + (see :class:`~torchrl.modules.LSTMModule` and :class:`~torchrl.modules.GRUModule` constructors). + + .. seealso:: :func:`~torchrl.modules.recurrent_mode`. + + .. note:: TorchRL loss modules run their ``forward`` and ``*_loss`` methods with ``set_recurrent_mode(True)`` by default. + + """ + + def __init__( + self, mode: bool | typing.Literal["recurrent", "sequential"] | None = True + ) -> None: + super().__init__() + if isinstance(mode, str): + if mode.lower() in ("recurrent",): + mode = True + elif mode.lower() in ("sequential",): + mode = False + else: + raise ValueError( + f"Unsupported recurrent mode. 
Must be a bool, or one of {('recurrent', 'sequential')}" + ) + self.mode = mode + + def clone(self) -> set_recurrent_mode: + # override this method if your children class takes __init__ parameters + return type(self)(self.mode) + + def __enter__(self) -> None: + self.prev = recurrent_mode_state_manager.get_mode() + recurrent_mode_state_manager.set_mode(self.mode) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + recurrent_mode_state_manager.set_mode(self.prev) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 57310a5fc3d..d54671f569b 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -21,6 +21,7 @@ from torch.nn import Parameter from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import set_recurrent_mode from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase @@ -46,7 +47,9 @@ def _updater_check_forward_prehook(module, *args, **kwargs): def _forward_wrapper(func): @functools.wraps(func) def new_forward(self, *args, **kwargs): - with set_exploration_type(self.deterministic_sampling_mode): + with set_exploration_type(self.deterministic_sampling_mode), set_recurrent_mode( + True + ): return func(self, *args, **kwargs) return new_forward @@ -56,6 +59,9 @@ class _LossMeta(abc.ABCMeta): def __init__(cls, name, bases, attr_dict): super().__init__(name, bases, attr_dict) cls.forward = _forward_wrapper(cls.forward) + for name, value in cls.__dict__.items(): + if not name.startswith("_") and name.endswith("loss"): + setattr(cls, name, _forward_wrapper(value)) class LossModule(TensorDictModuleBase, metaclass=_LossMeta): diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 2e6a111f258..a0373ba4b46 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -431,8 +431,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.single_action_spec.space.low, - "high": env.single_action_spec.space.high, + "low": env.action_spec_unbatched.space.low, + "high": env.action_spec_unbatched.space.high, }, return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 8931f483384..58c47f68321 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -317,7 +317,7 @@ # # We can now put things together in a :class:`~tensordict.nn.TensorDictSequential` # -stoch_policy = Seq(feature, lstm, mlp, qval) +policy = Seq(feature, lstm, mlp, qval) ###################################################################### # DQN being a deterministic algorithm, exploration is a crucial part of it. @@ -330,7 +330,7 @@ annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 ) stoch_policy = TensorDictSequential( - stoch_policy, + policy, exploration_module, ) @@ -338,20 +338,17 @@ # Using the model for the loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The model as we've built it is well equipped to be used in sequential settings. +# The model as we've built it is well-equipped to be used in sequential settings. # However, the class :class:`torch.nn.LSTM` can use a cuDNN-optimized backend # to run the RNN sequence faster on GPU device. 
We would not want to miss # such an opportunity to speed up our training loop! -# To use it, we just need to tell the LSTM module to run on "recurrent-mode" -# when used by the loss. -# As we'll usually want to have two copies of the LSTM module, we do this by -# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that -# will return a new instance of the LSTM (with shared weights) that will -# assume that the input data is sequential in nature. # -policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval) - -###################################################################### +# By default, torchrl losses will use this when executing any +# :class:`~torchrl.modules.LSTMModule` or :class:`~torchrl.modules.GRUModule` +# forward call. If you need to control this manually, the RNN modules are sensitive +# to a context manager/decorator, :class:`~torchrl.modules.set_recurrent_mode`, +# that handles the behaviour of the underlying RNN module. +# # Because we still have a couple of uninitialized parameters we should # initialize them before creating an optimizer and such. # diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py index 0a4390abdfc..48dd8723ffc 100644 --- a/tutorials/sphinx-tutorials/export.py +++ b/tutorials/sphinx-tutorials/export.py @@ -265,10 +265,6 @@ in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", "hidden0", "hidden1"], ) -##################################### -# We set the recurrent mode to ``False`` to allow the module to read inputs one-by-one and not in batch. -# -lstm = lstm.set_recurrent_mode(False) ##################################### # If the LSTM module is not python based but CuDNN (:class:`~torch.nn.LSTM`), the :meth:`~torchrl.modules.LSTMModule.make_python_based` diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index 4ee3ae02a9f..a7bd74a4deb 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -486,8 +486,8 @@ out_keys=[(group, "action")], distribution_class=TanhDelta, distribution_kwargs={ - "low": env.single_full_action_spec[group, "action"].space.low, - "high": env.single_full_action_spec[group, "action"].space.high, + "low": env.full_action_spec_unbatched[group, "action"].space.low, + "high": env.full_action_spec_unbatched[group, "action"].space.high, }, return_log_prob=False, ) diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 700eb1634a9..e2ca3f6ecd8 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -445,13 +445,13 @@ policy = ProbabilisticActor( module=policy_module, - spec=env.single_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.single_action_spec[env.action_key].space.low, - "high": env.single_action_spec[env.action_key].space.high, + "low": env.action_spec_unbatched[env.action_key].space.low, + "high": env.action_spec_unbatched[env.action_key].space.high, }, return_log_prob=True, log_prob_key=("agents", "sample_log_prob"),