diff --git a/test/test_cost.py b/test/test_cost.py
index 084c69d1970..65e0c41af73 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -3213,13 +3213,20 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
 class TestA2C(LossModuleTestBase):
     seed = 0
 
-    def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
+    def _create_mock_actor(
+        self,
+        batch=2,
+        obs_dim=3,
+        action_dim=4,
+        device="cpu",
+        observation_key="observation",
+    ):
         # Actor
         action_spec = BoundedTensorSpec(
             -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
         )
         net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
-        module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+        module = SafeModule(net, in_keys=[observation_key], out_keys=["loc", "scale"])
         actor = ProbabilisticActor(
             module=module,
             in_keys=["loc", "scale"],
@@ -3229,12 +3236,18 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
         return actor.to(device)
 
     def _create_mock_value(
-        self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
+        self,
+        batch=2,
+        obs_dim=3,
+        action_dim=4,
+        device="cpu",
+        out_keys=None,
+        observation_key="observation",
     ):
         module = nn.Linear(obs_dim, 1)
         value = ValueOperator(
             module=module,
-            in_keys=["observation"],
+            in_keys=[observation_key],
             out_keys=out_keys,
         )
         return value.to(device)
@@ -3248,6 +3261,9 @@ def _create_seq_mock_data_a2c(
         atoms=None,
         device="cpu",
         action_key="action",
+        observation_key="observation",
+        reward_key="reward",
+        done_key="done",
     ):
         # create a tensordict
         total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -3267,11 +3283,11 @@ def _create_seq_mock_data_a2c(
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                observation_key: obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "next": {
-                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
-                    "done": done,
-                    "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                    observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                    done_key: done,
+                    reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 },
                 "collector": {"mask": mask},
                 action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
@@ -3443,6 +3459,8 @@ def test_a2c_tensordict_keys(self, td_est):
             "value_target": "value_target",
             "value": "state_value",
             "action": "action",
+            "reward": "reward",
+            "done": "done",
         }
 
         self.tensordict_keys_test(
@@ -3459,6 +3477,8 @@ def test_a2c_tensordict_keys(self, td_est):
             "advantage": ("advantage", "advantage_test"),
             "value_target": ("value_target", "value_target_test"),
             "value": ("value", "value_state_test"),
+            "reward": ("reward", "reward_test"),
+            "done": ("done", ("done", "test")),
         }
         self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)
 
@@ -3471,8 +3491,15 @@ def test_a2c_tensordict_keys_run(self, device):
         value_target_key = "value_target_test"
         value_key = "state_value_test"
         action_key = "action_test"
+        reward_key = "reward_test"
+        done_key = ("done", "test")
 
-        td = self._create_seq_mock_data_a2c(device=device, action_key=action_key)
+        td = self._create_seq_mock_data_a2c(
+            device=device,
+            action_key=action_key,
+            reward_key=reward_key,
+            done_key=done_key,
+        )
 
         actor = self._create_mock_actor(device=device)
         value = self._create_mock_value(device=device, out_keys=[value_key])
@@ -3486,6 +3513,8 @@ def test_a2c_tensordict_keys_run(self, device):
             advantage=advantage_key,
             value_target=value_target_key,
             value=value_key,
+            reward=reward_key,
+            done=done_key,
         )
         loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
         loss_fn.set_keys(
@@ -3493,6 +3522,8 @@ def test_a2c_tensordict_keys_run(self, device):
             value_target=value_target_key,
             value=value_key,
             action=action_key,
+            reward=reward_key,
+            done=done_key,
         )
 
         advantage(td)
@@ -3525,6 +3556,42 @@ def test_a2c_tensordict_keys_run(self, device):
         # test reset
         loss_fn.reset()
 
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key):
+        torch.manual_seed(self.seed)
+
+        actor = self._create_mock_actor(observation_key=observation_key)
+        value = self._create_mock_value(observation_key=observation_key)
+        td = self._create_seq_mock_data_a2c(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+        )
+
+        loss = A2CLoss(actor, value)
+        loss.set_keys(action=action_key, reward=reward_key, done=done_key)
+
+        kwargs = {
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": td.get(("next", done_key)),
+            action_key: td.get(action_key),
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+
+        loss_val = loss(**kwargs)
+        loss_val_td = loss(td)
+
+        torch.testing.assert_close(loss_val_td.get("loss_objective"), loss_val[0])
+        torch.testing.assert_close(loss_val_td.get("loss_critic"), loss_val[1])
+        # don't test entropy and loss_entropy, since they depend on a random sample
+        # from distribution
+        assert len(loss_val) == 4
+
 
 class TestReinforce(LossModuleTestBase):
     @pytest.mark.parametrize("delay_value", [True, False])
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index f6953cb7985..ca2dbafa88a 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -7,7 +7,7 @@
 from typing import Tuple
 
 import torch
-from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -67,6 +67,88 @@ class A2CLoss(LossModule):
         The default is :class:`~torchrl.objectives.value.GAE` with hyperparameters
         dictated by :func:`~torchrl.objectives.utils.default_value_kwargs`.
 
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from torchrl.data import BoundedTensorSpec
+        >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
+        >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+        >>> from torchrl.modules.tensordict_module.common import SafeModule
+        >>> from torchrl.objectives.a2c import A2CLoss
+        >>> from tensordict.tensordict import TensorDict
+        >>> n_act, n_obs = 4, 3
+        >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+        >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
+        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+        >>> actor = ProbabilisticActor(
+        ...     module=module,
+        ...     in_keys=["loc", "scale"],
+        ...     spec=spec,
+        ...     distribution_class=TanhNormal)
+        >>> module = nn.Linear(n_obs, 1)
+        >>> value = ValueOperator(
+        ...     module=module,
in_keys=["observation"]) + >>> loss = A2CLoss(actor, value, loss_critic_type="l2") + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done"]`` + in_keys of the actor and critic. + The return value is a tuple of tensors in the following order: + ``["loss_objective"]`` + + ``["loss_critic"]`` if critic_coef is not None + + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.a2c import A2CLoss + >>> _ = torch.manual_seed(42) + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> module = nn.Linear(n_obs, 1) + >>> value = ValueOperator( + ... module=module, + ... in_keys=["observation"]) + >>> loss = A2CLoss(actor, value, loss_critic_type="l2") + >>> batch = [2, ] + >>> loss_val = loss( + ... observation = torch.randn(*batch, n_obs), + ... action = spec.rand(batch), + ... next_done = torch.zeros(*batch, 1, dtype=torch.bool), + ... next_reward = torch.randn(*batch, 1)) + >>> loss_val + (tensor(1.7593, grad_fn=), tensor(0.2344, grad_fn=), tensor(1.5480), tensor(-0.0155, grad_fn=)) """ @dataclass @@ -85,12 +167,19 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"state_value"``. action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. 
""" advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" action: NestedKey = "action" + reward: NestedKey = "reward" + done: NestedKey = "done" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE @@ -141,9 +230,11 @@ def __init__( def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( - advantage=self._tensor_keys.advantage, - value_target=self._tensor_keys.value_target, - value=self._tensor_keys.value, + advantage=self.tensor_keys.advantage, + value_target=self.tensor_keys.value_target, + value=self.tensor_keys.value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, ) def reset(self) -> None: @@ -198,6 +289,29 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: ) return self.critic_coef * loss_value + @property + def in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ] + keys.extend(self.actor.in_keys) + if self.critic_coef: + keys.extend(self.critic.in_keys) + return list(set(keys)) + + @property + def out_keys(self): + outs = ["loss_objective"] + if self.critic_coef: + outs.append("loss_critic") + if self.entropy_bonus: + outs.append("entropy") + outs.append("loss_entropy") + return outs + + @dispatch() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) advantage = tensordict.get(self.tensor_keys.advantage, None) @@ -243,5 +357,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "advantage": self.tensor_keys.advantage, "value": self.tensor_keys.value, "value_target": self.tensor_keys.value_target, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, } self._value_estimator.set_keys(**tensor_keys)