From f9c4e00805739c940cf71e1a56e06ae29a9d5fbb Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 20 Nov 2024 14:44:20 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 sota-implementations/multiagent/iql.py          |  4 ++--
 sota-implementations/multiagent/maddpg_iddpg.py |  8 ++++----
 sota-implementations/multiagent/mappo_ippo.py   |  6 +++---
 sota-implementations/multiagent/qmix_vdn.py     |  4 ++--
 sota-implementations/multiagent/sac.py          | 12 ++++++------
 5 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py
index 0a5372a62b0..04134bab951 100644
--- a/sota-implementations/multiagent/iql.py
+++ b/sota-implementations/multiagent/iql.py
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.action_spec_unbatched,
+        spec=env.full_action_spec_unbatched,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.action_spec_unbatched,
+            spec=env.full_action_spec_unbatched,
         ),
     )
 
diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py
index d666d0d6982..1485e3e8c0b 100644
--- a/sota-implementations/multiagent/maddpg_iddpg.py
+++ b/sota-implementations/multiagent/maddpg_iddpg.py
@@ -91,13 +91,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.action_spec_unbatched,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "param")],
         out_keys=[env.action_key],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "low": env.action_spec_unbatched[("agents", "action")].space.low,
-            "high": env.action_spec_unbatched[("agents", "action")].space.high,
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
         },
         return_log_prob=False,
     )
@@ -105,7 +105,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
     policy_explore = TensorDictSequential(
         policy,
         AdditiveGaussianModule(
-            spec=env.action_spec_unbatched,
+            spec=env.full_action_spec_unbatched,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
             device=cfg.train.device,
diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py
index 6f50c79b149..06cc2cd1fce 100644
--- a/sota-implementations/multiagent/mappo_ippo.py
+++ b/sota-implementations/multiagent/mappo_ippo.py
@@ -92,13 +92,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.action_spec_unbatched,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "loc"), ("agents", "scale")],
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": env.action_spec_unbatched[("agents", "action")].space.low,
-            "high": env.action_spec_unbatched[("agents", "action")].space.high,
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index a1a3eb35618..6e619179b4b 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.action_spec_unbatched,
+        spec=env.full_action_spec_unbatched,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.action_spec_unbatched,
+            spec=env.full_action_spec_unbatched,
         ),
     )
 
diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py
index f0b11fe6b9c..694083e5b0f 100644
--- a/sota-implementations/multiagent/sac.py
+++ b/sota-implementations/multiagent/sac.py
@@ -96,13 +96,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
 
         policy = ProbabilisticActor(
             module=policy_module,
-            spec=env.action_spec_unbatched,
+            spec=env.full_action_spec_unbatched,
             in_keys=[("agents", "loc"), ("agents", "scale")],
             out_keys=[env.action_key],
             distribution_class=TanhNormal,
             distribution_kwargs={
-                "low": env.action_spec_unbatched[("agents", "action")].space.low,
-                "high": env.action_spec_unbatched[("agents", "action")].space.high,
+                "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+                "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
             },
             return_log_prob=True,
         )
@@ -146,7 +146,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         )
         policy = ProbabilisticActor(
             module=policy_module,
-            spec=env.action_spec_unbatched,
+            spec=env.full_action_spec_unbatched,
             in_keys=[("agents", "logits")],
             out_keys=[env.action_key],
             distribution_class=OneHotCategorical
@@ -194,7 +194,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             actor_network=policy,
             qvalue_network=value_module,
             delay_qvalue=True,
-            action_spec=env.action_spec_unbatched,
+            action_spec=env.full_action_spec_unbatched,
         )
         loss_module.set_keys(
             state_action_value=("agents", "state_action_value"),
@@ -209,7 +209,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             qvalue_network=value_module,
             delay_qvalue=True,
             num_actions=env.action_spec.space.n,
-            action_space=env.action_spec_unbatched,
+            action_space=env.full_action_spec_unbatched,
         )
         loss_module.set_keys(
             action_value=("agents", "action_value"),