From 6f78a4ad4ca3e5bbc51687b66fe3e9930bae4dca Mon Sep 17 00:00:00 2001 From: Bryan Lim <46229436+limbryan@users.noreply.github.com> Date: Thu, 30 Mar 2023 00:01:24 +0100 Subject: [PATCH] chore: uniform sac network sizes (#145) * Uniformize architecture definition across RL algs. Separate actor and critic architecture definition for all SAC related networks including DADS and DIAYN too --- examples/dads.ipynb | 18 +++++++++++------- examples/diayn.ipynb | 18 +++++++++++------- examples/me_sac_pbt.ipynb | 18 +++++++++++------- examples/sac_pbt.ipynb | 15 +++++++++++---- examples/smerl.ipynb | 18 +++++++++++------- qdax/baselines/dads.py | 2 ++ qdax/baselines/diayn.py | 3 ++- qdax/baselines/sac.py | 7 +++++-- qdax/baselines/sac_pbt.py | 6 ++++-- .../neuroevolution/networks/dads_networks.py | 11 ++++++----- .../neuroevolution/networks/diayn_networks.py | 11 ++++++----- .../neuroevolution/networks/sac_networks.py | 9 +++++---- tests/baselines_test/dads_smerl_test.py | 6 ++++-- tests/baselines_test/dads_test.py | 6 ++++-- tests/baselines_test/diayn_smerl_test.py | 6 ++++-- tests/baselines_test/diayn_test.py | 6 ++++-- tests/baselines_test/me_pbt_sac_test.py | 6 ++++-- tests/baselines_test/pbt_sac_test.py | 6 ++++-- tests/baselines_test/sac_test.py | 6 ++++-- 19 files changed, 113 insertions(+), 65 deletions(-) diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f64f4685..72380b34 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -116,7 +116,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DADS config\n", @@ -202,7 +203,8 @@ " alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DADS config\n", " num_skills=num_skills,\n", @@ -520,11 +522,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -538,7 +537,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 10cfda49..c48ab765 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -116,7 +116,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DIAYN config\n", @@ -200,7 +201,8 @@ " alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DIAYN config\n", " num_skills=num_skills,\n", @@ -510,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -528,7 +527,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 3f856c3d..6d4dfdfe 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -63,7 +63,8 @@ "episode_length = 1000\n", "tau = 0.005\n", "alpha_init = 1.0\n", - "hidden_layer_sizes = (256, 256)\n", + "critic_hidden_layer_size = (256, 256) \n", + "policy_hidden_layer_size = (256, 256) \n", "fix_alpha = False\n", "normalize_observations = False\n", "\n", @@ -148,7 +149,8 @@ " tau=tau,\n", " normalize_observations=normalize_observations,\n", " alpha_init=alpha_init,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", ")\n", "\n", @@ -527,11 +529,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -545,7 +544,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 37bc06d7..4f225667 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -95,7 +95,8 @@ "grad_updates_per_step = 1.0\n", "tau = 0.005\n", "alpha_init = 1.0\n", - "hidden_layer_sizes = (256, 256)\n", + "critic_hidden_layer_size = (256, 256) \n", + "policy_hidden_layer_size = (256, 256)\n", "fix_alpha = False\n", "normalize_observations = False\n", "\n", @@ -217,7 +218,8 @@ " tau=tau,\n", " normalize_observations=normalize_observations,\n", " alpha_init=alpha_init,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", ")\n", "\n", @@ -544,7 +546,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -558,7 +560,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 8042c8cf..98d57b94 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -117,7 +117,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DIAYN config\n", @@ -212,7 +213,8 @@ " alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DIAYN config\n", " num_skills=num_skills,\n", @@ -525,11 +527,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -543,7 +542,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 310a9c8b..41f2ff08 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -84,6 +84,8 @@ def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int): action_size=action_size, descriptor_size=descriptor_size, omit_input_dynamics_dim=config.omit_input_dynamics_dim, + policy_hidden_layer_size=config.policy_hidden_layer_size, + critic_hidden_layer_size=config.critic_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index e2def709..c03cfb3f 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -78,7 +78,8 @@ def __init__(self, config: DiaynConfig, action_size: int): self._policy, self._critic, self._discriminator = make_diayn_networks( num_skills=self._config.num_skills, action_size=action_size, - hidden_layer_sizes=self._config.hidden_layer_sizes, + policy_hidden_layer_size=self._config.policy_hidden_layer_size, + critic_hidden_layer_size=self._config.critic_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index 90a823b6..a5ce15c5 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -71,7 +71,8 @@ class SacConfig: alpha_init: float = 1.0 discount: float = 0.97 reward_scaling: float = 1.0 - hidden_layer_sizes: tuple = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha: bool = False @@ -82,7 +83,9 @@ def __init__(self, config: SacConfig, action_size: int) -> None: # define the networks self._policy, self._critic = make_sac_networks( - action_size=action_size, hidden_layer_sizes=self._config.hidden_layer_sizes + action_size=action_size, + critic_hidden_layer_size=self._config.critic_hidden_layer_size, + policy_hidden_layer_size=self._config.policy_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index f5fd24c1..9aa2ff4c 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -110,7 +110,8 @@ class PBTSacConfig: tau: float = 0.005 normalize_observations: bool = False alpha_init: float = 1.0 - hidden_layer_sizes: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) fix_alpha: bool = False @@ -123,7 +124,8 @@ def __init__(self, config: PBTSacConfig, action_size: int) -> None: tau=config.tau, normalize_observations=config.normalize_observations, alpha_init=config.alpha_init, - hidden_layer_sizes=config.hidden_layer_sizes, + policy_hidden_layer_size=config.policy_hidden_layer_size, + critic_hidden_layer_size=config.critic_hidden_layer_size, fix_alpha=config.fix_alpha, # unused default values for parameters that will be learnt as part of PBT learning_rate=3e-4, diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py index 785e4ef3..beb4b77a 100644 --- a/qdax/core/neuroevolution/networks/dads_networks.py +++ b/qdax/core/neuroevolution/networks/dads_networks.py @@ -126,7 +126,8 @@ def __call__( def make_dads_networks( action_size: int, descriptor_size: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), omit_input_dynamics_dim: int = 2, identity_covariance: bool = True, dynamics_initializer: Optional[Initializer] = None, @@ -155,7 +156,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -167,7 +168,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -176,7 +177,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -191,7 +192,7 @@ def _dynamics_fn( obs: StateDescriptor, skill: Skill, target: StateDescriptor ) -> jnp.ndarray: dynamics_network = DynamicsNetwork( - hidden_layer_sizes, + critic_hidden_layer_size, descriptor_size, omit_input_dynamics_dim=omit_input_dynamics_dim, identity_covariance=identity_covariance, diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py index dc45d298..c656cace 100644 --- a/qdax/core/neuroevolution/networks/diayn_networks.py +++ b/qdax/core/neuroevolution/networks/diayn_networks.py @@ -10,7 +10,8 @@ def make_diayn_networks( action_size: int, num_skills: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), ) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: """Creates networks used in DIAYN. @@ -30,7 +31,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -42,7 +43,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -51,7 +52,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -66,7 +67,7 @@ def _discriminator_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [num_skills], + list(critic_hidden_layer_size) + [num_skills], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py index be6db1b2..dcadfaa2 100644 --- a/qdax/core/neuroevolution/networks/sac_networks.py +++ b/qdax/core/neuroevolution/networks/sac_networks.py @@ -9,7 +9,8 @@ def make_sac_networks( action_size: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), ) -> Tuple[hk.Transformed, hk.Transformed]: """Creates networks used in SAC. @@ -27,7 +28,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -39,7 +40,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -48,7 +49,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 767407ab..1e782f2a 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -31,7 +31,8 @@ def test_dads_smerl() -> None: tau = 0.005 grad_updates_per_step = 0.25 normalize_observations = False - hidden_layer_sizes = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) alpha_init = 1.0 fix_alpha = False discount = 0.97 @@ -102,7 +103,8 @@ def test_dads_smerl() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DADS config num_skills=num_skills, diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 54f486d6..0b9af46e 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -30,7 +30,8 @@ def test_dads() -> None: tau = 0.005 grad_updates_per_step = 0.25 normalize_observations = False - hidden_layer_sizes = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) alpha_init = 1.0 fix_alpha = False discount = 0.97 @@ -93,7 +94,8 @@ def test_dads() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DADS config num_skills=num_skills, diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index bb75c37e..abd94b45 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -34,7 +34,8 @@ def test_diayn_smerl() -> None: alpha_init = 1.0 discount = 0.97 reward_scaling = 1.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False normalize_observations = False # DIAYN config @@ -100,7 +101,8 @@ def test_diayn_smerl() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DIAYN config num_skills=num_skills, diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index c0dd8c09..3492d9c1 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -33,7 +33,8 @@ def test_diayn() -> None: alpha_init = 1.0 discount = 0.97 reward_scaling = 1.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False normalize_observations = False # DIAYN config @@ -85,7 +86,8 @@ def test_diayn() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DIAYN config num_skills=num_skills, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 5bbc8241..079fde45 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -27,7 +27,8 @@ def test_me_pbt_sac() -> None: episode_length = 100 tau = 0.005 alpha_init = 1.0 - hidden_layer_sizes = (64, 64) + policy_hidden_layer_size = (64, 64) + critic_hidden_layer_size = (64, 64) fix_alpha = False normalize_observations = False @@ -79,7 +80,8 @@ def test_me_pbt_sac() -> None: tau=tau, normalize_observations=normalize_observations, alpha_init=alpha_init, - hidden_layer_sizes=hidden_layer_sizes, + policy_hidden_layer_size=policy_hidden_layer_size, + critic_hidden_layer_size=critic_hidden_layer_size, fix_alpha=fix_alpha, ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index d8cbcf58..c83f277c 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -31,7 +31,8 @@ def test_pbt_sac() -> None: grad_updates_per_step = 1.0 tau = 0.005 alpha_init = 1.0 - hidden_layer_sizes = (64, 64) + policy_hidden_layer_size = (64, 64) + critic_hidden_layer_size = (64, 64) fix_alpha = False normalize_observations = False @@ -89,7 +90,8 @@ def init_environments(random_key): # type: ignore tau=tau, normalize_observations=normalize_observations, alpha_init=alpha_init, - hidden_layer_sizes=hidden_layer_sizes, + policy_hidden_layer_size=policy_hidden_layer_size, + critic_hidden_layer_size=critic_hidden_layer_size, fix_alpha=fix_alpha, ) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index f4029beb..c667aa66 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -31,7 +31,8 @@ def test_sac() -> None: alpha_init = 1.0 discount = 0.95 reward_scaling = 10.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False # Initialize environments @@ -73,7 +74,8 @@ def test_sac() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, )