Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 20, 2024
2 parents c40a365 + 0c3fad6 commit b59d5de
Show file tree
Hide file tree
Showing 36 changed files with 367 additions and 134 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer.
OnlineDTActor
RSSMPosterior
RSSMPrior
set_recurrent_mode
recurrent_mode

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.single_action_spec.space.low,
"high": env.single_action_spec.space.high,
"low": env.action_spec_unbatched.space.low,
"high": env.action_spec_unbatched.space.high,
},
return_log_prob=True,
)
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.single_action_spec.space.low.to(device),
"high": proof_environment.single_action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.single_action_spec.space.low.to(device),
"high": proof_environment.single_action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
"safe_tanh": True,
}
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg):
def make_cql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched

actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
in_keys = ["observation"]
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/crossq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
"""Make CrossQ agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.actor_hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def make_dt_model(cfg):
make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
)

action_spec = proof_environment.single_action_spec
action_spec = proof_environment.action_spec_unbatched
for key, value in proof_environment.observation_spec.items():
if key == "observation":
state_dim = value.shape[-1]
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.single_action_spec.space.low,
"high": proof_environment.single_action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

in_keys = ["observation"]
action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched
actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)

out_keys = ["loc", "scale"]
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
),
)

Expand Down
8 changes: 4 additions & 4 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
in_keys=[("agents", "param")],
out_keys=[env.action_key],
distribution_class=TanhDelta,
distribution_kwargs={
"low": env.single_action_spec[("agents", "action")].space.low,
"high": env.single_action_spec[("agents", "action")].space.high,
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=False,
)

policy_explore = TensorDictSequential(
policy,
AdditiveGaussianModule(
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
device=cfg.train.device,
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.single_action_spec[("agents", "action")].space.low,
"high": env.single_action_spec[("agents", "action")].space.high,
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
),
)

Expand Down
12 changes: 6 additions & 6 deletions sota-implementations/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821

policy = ProbabilisticActor(
module=policy_module,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.single_action_spec[("agents", "action")].space.low,
"high": env.single_action_spec[("agents", "action")].space.high,
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down Expand Up @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.single_action_spec,
spec=env.action_spec_unbatched,
in_keys=[("agents", "logits")],
out_keys=[env.action_key],
distribution_class=OneHotCategorical
Expand Down Expand Up @@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
action_spec=env.single_action_spec,
action_spec=env.action_spec_unbatched,
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
Expand All @@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821
qvalue_network=value_module,
delay_qvalue=True,
num_actions=env.action_spec.space.n,
action_space=env.single_action_spec,
action_space=env.action_spec_unbatched,
)
loss_module.set_keys(
action_value=("agents", "action_value"),
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.single_action_spec.space.low,
"high": proof_environment.single_action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.single_action_spec.space.low,
"high": proof_environment.single_action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def make_redq_model(
default_policy_scale = cfg.network.default_policy_scale
gSDE = cfg.exploration.gSDE

action_spec = proof_environment.single_action_spec
action_spec = proof_environment.action_spec_unbatched

if actor_net_kwargs is None:
actor_net_kwargs = {}
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
"""Make SAC agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
Expand Down
42 changes: 8 additions & 34 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,17 +1388,17 @@ def _make_specs(self):
obs_spec_unlazy = consolidate_spec(obs_specs)
action_specs = torch.stack(action_specs, dim=0)

self.unbatched_observation_spec = Composite(
self.observation_spec_unbatched = Composite(
lazy=obs_spec_unlazy,
state=Unbounded(shape=(64, 64, 3)),
device=self.device,
)

self.single_action_spec = Composite(
self.action_spec_unbatched = Composite(
lazy=action_specs,
device=self.device,
)
self.unbatched_reward_spec = Composite(
self.reward_spec_unbatched = Composite(
{
"lazy": Composite(
{"reward": Unbounded(shape=(self.n_nested_dim, 1))},
Expand All @@ -1407,7 +1407,7 @@ def _make_specs(self):
},
device=self.device,
)
self.unbatched_done_spec = Composite(
self.done_spec_unbatched = Composite(
{
"lazy": Composite(
{
Expand All @@ -1423,19 +1423,6 @@ def _make_specs(self):
device=self.device,
)

self.action_spec = self.single_action_spec.expand(
*self.batch_size, *self.single_action_spec.shape
)
self.observation_spec = self.unbatched_observation_spec.expand(
*self.batch_size, *self.unbatched_observation_spec.shape
)
self.reward_spec = self.unbatched_reward_spec.expand(
*self.batch_size, *self.unbatched_reward_spec.shape
)
self.done_spec = self.unbatched_done_spec.expand(
*self.batch_size, *self.unbatched_done_spec.shape
)

def get_agent_obs_spec(self, i):
camera = Bounded(low=0, high=200, shape=(7, 7, 3))
vector_3d = Unbounded(shape=(3,))
Expand Down Expand Up @@ -1610,21 +1597,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):

self.make_specs()

self.action_spec = self.single_action_spec.expand(
*self.batch_size, *self.single_action_spec.shape
)
self.observation_spec = self.unbatched_observation_spec.expand(
*self.batch_size, *self.unbatched_observation_spec.shape
)
self.reward_spec = self.unbatched_reward_spec.expand(
*self.batch_size, *self.unbatched_reward_spec.shape
)
self.done_spec = self.unbatched_done_spec.expand(
*self.batch_size, *self.unbatched_done_spec.shape
)

def make_specs(self):
self.unbatched_observation_spec = Composite(
self.observation_spec_unbatched = Composite(
nested_1=Composite(
observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)),
shape=(self.nested_dim_1,),
Expand All @@ -1642,7 +1616,7 @@ def make_specs(self):
),
)

self.single_action_spec = Composite(
self.action_spec_unbatched = Composite(
nested_1=Composite(
action=Categorical(n=2, shape=(self.nested_dim_1,)),
shape=(self.nested_dim_1,),
Expand All @@ -1654,7 +1628,7 @@ def make_specs(self):
action=OneHot(n=2),
)

self.unbatched_reward_spec = Composite(
self.reward_spec_unbatched = Composite(
nested_1=Composite(
gift=Unbounded(shape=(self.nested_dim_1, 1)),
shape=(self.nested_dim_1,),
Expand All @@ -1666,7 +1640,7 @@ def make_specs(self):
reward=Unbounded(shape=(1,)),
)

self.unbatched_done_spec = Composite(
self.done_spec_unbatched = Composite(
nested_1=Composite(
done=Categorical(
n=2,
Expand Down
24 changes: 24 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
DistributionalQValueActor,
OneHotCategorical,
QValueActor,
recurrent_mode,
SafeSequential,
WorldModelWrapper,
)
Expand Down Expand Up @@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs):


class TestBase:
def test_decorators(self):
class MyLoss(LossModule):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

loss = MyLoss()
loss.forward(None)
loss.actor_loss(None)
loss.something_loss(None)
assert not recurrent_mode()

@pytest.mark.parametrize("expand_dim", [None, 2])
@pytest.mark.parametrize("compare_against", [True, False])
@pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion")
Expand Down
Loading

0 comments on commit b59d5de

Please sign in to comment.