Skip to content

Commit

Permalink
chore: uniform sac network sizes (#145)
Browse files Browse the repository at this point in the history
* Uniformize the architecture definitions across RL algorithms. Separate the actor and critic architecture definitions for all SAC-related networks, including DADS and DIAYN.
  • Loading branch information
limbryan authored Mar 29, 2023
1 parent 6fa19e7 commit 6f78a4a
Show file tree
Hide file tree
Showing 19 changed files with 113 additions and 65 deletions.
18 changes: 11 additions & 7 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@
"alpha_init = 1.0 #@param {type:\"number\"}\n",
"discount = 0.97 #@param {type:\"number\"}\n",
"reward_scaling = 1.0 #@param {type:\"number\"}\n",
"hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n",
"critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"fix_alpha = False #@param {type:\"boolean\"}\n",
"normalize_observations = False #@param {type:\"boolean\"}\n",
"# DADS config\n",
Expand Down Expand Up @@ -202,7 +203,8 @@
" alpha_init=alpha_init,\n",
" discount=discount,\n",
" reward_scaling=reward_scaling,\n",
" hidden_layer_sizes=hidden_layer_sizes,\n",
" critic_hidden_layer_size=critic_hidden_layer_size,\n",
" policy_hidden_layer_size=policy_hidden_layer_size,\n",
" fix_alpha=fix_alpha,\n",
" # DADS config\n",
" num_skills=num_skills,\n",
Expand Down Expand Up @@ -520,11 +522,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.9.2 64-bit ('3.9.2')",
"language": "python",
"name": "python3"
},
Expand All @@ -538,7 +537,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.2"
},
"vscode": {
"interpreter": {
"hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534"
}
}
},
"nbformat": 4,
Expand Down
18 changes: 11 additions & 7 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@
"alpha_init = 1.0 #@param {type:\"number\"}\n",
"discount = 0.97 #@param {type:\"number\"}\n",
"reward_scaling = 1.0 #@param {type:\"number\"}\n",
"hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n",
"critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"fix_alpha = False #@param {type:\"boolean\"}\n",
"normalize_observations = False #@param {type:\"boolean\"}\n",
"# DIAYN config\n",
Expand Down Expand Up @@ -200,7 +201,8 @@
" alpha_init=alpha_init,\n",
" discount=discount,\n",
" reward_scaling=reward_scaling,\n",
" hidden_layer_sizes=hidden_layer_sizes,\n",
" critic_hidden_layer_size=critic_hidden_layer_size,\n",
" policy_hidden_layer_size=policy_hidden_layer_size,\n",
" fix_alpha=fix_alpha,\n",
" # DIAYN config\n",
" num_skills=num_skills,\n",
Expand Down Expand Up @@ -510,11 +512,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.9.2 64-bit ('3.9.2')",
"language": "python",
"name": "python3"
},
Expand All @@ -528,7 +527,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.2"
},
"vscode": {
"interpreter": {
"hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534"
}
}
},
"nbformat": 4,
Expand Down
18 changes: 11 additions & 7 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
"episode_length = 1000\n",
"tau = 0.005\n",
"alpha_init = 1.0\n",
"hidden_layer_sizes = (256, 256)\n",
"critic_hidden_layer_size = (256, 256) \n",
"policy_hidden_layer_size = (256, 256) \n",
"fix_alpha = False\n",
"normalize_observations = False\n",
"\n",
Expand Down Expand Up @@ -148,7 +149,8 @@
" tau=tau,\n",
" normalize_observations=normalize_observations,\n",
" alpha_init=alpha_init,\n",
" hidden_layer_sizes=hidden_layer_sizes,\n",
" critic_hidden_layer_size=critic_hidden_layer_size,\n",
" policy_hidden_layer_size=policy_hidden_layer_size,\n",
" fix_alpha=fix_alpha,\n",
")\n",
"\n",
Expand Down Expand Up @@ -527,11 +529,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.9.2 64-bit ('3.9.2')",
"language": "python",
"name": "python3"
},
Expand All @@ -545,7 +544,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.2"
},
"vscode": {
"interpreter": {
"hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534"
}
}
},
"nbformat": 4,
Expand Down
15 changes: 11 additions & 4 deletions examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@
"grad_updates_per_step = 1.0\n",
"tau = 0.005\n",
"alpha_init = 1.0\n",
"hidden_layer_sizes = (256, 256)\n",
"critic_hidden_layer_size = (256, 256) \n",
"policy_hidden_layer_size = (256, 256)\n",
"fix_alpha = False\n",
"normalize_observations = False\n",
"\n",
Expand Down Expand Up @@ -217,7 +218,8 @@
" tau=tau,\n",
" normalize_observations=normalize_observations,\n",
" alpha_init=alpha_init,\n",
" hidden_layer_sizes=hidden_layer_sizes,\n",
" critic_hidden_layer_size=critic_hidden_layer_size,\n",
" policy_hidden_layer_size=policy_hidden_layer_size,\n",
" fix_alpha=fix_alpha,\n",
")\n",
"\n",
Expand Down Expand Up @@ -544,7 +546,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.9.2 64-bit ('3.9.2')",
"language": "python",
"name": "python3"
},
Expand All @@ -558,7 +560,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.2"
},
"vscode": {
"interpreter": {
"hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534"
}
}
},
"nbformat": 4,
Expand Down
18 changes: 11 additions & 7 deletions examples/smerl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@
"alpha_init = 1.0 #@param {type:\"number\"}\n",
"discount = 0.97 #@param {type:\"number\"}\n",
"reward_scaling = 1.0 #@param {type:\"number\"}\n",
"hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n",
"critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n",
"fix_alpha = False #@param {type:\"boolean\"}\n",
"normalize_observations = False #@param {type:\"boolean\"}\n",
"# DIAYN config\n",
Expand Down Expand Up @@ -212,7 +213,8 @@
" alpha_init=alpha_init,\n",
" discount=discount,\n",
" reward_scaling=reward_scaling,\n",
" hidden_layer_sizes=hidden_layer_sizes,\n",
" critic_hidden_layer_size=critic_hidden_layer_size,\n",
" policy_hidden_layer_size=policy_hidden_layer_size,\n",
" fix_alpha=fix_alpha,\n",
" # DIAYN config\n",
" num_skills=num_skills,\n",
Expand Down Expand Up @@ -525,11 +527,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.9.2 64-bit ('3.9.2')",
"language": "python",
"name": "python3"
},
Expand All @@ -543,7 +542,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.2"
},
"vscode": {
"interpreter": {
"hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534"
}
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions qdax/baselines/dads.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int):
action_size=action_size,
descriptor_size=descriptor_size,
omit_input_dynamics_dim=config.omit_input_dynamics_dim,
policy_hidden_layer_size=config.policy_hidden_layer_size,
critic_hidden_layer_size=config.critic_hidden_layer_size,
)

# define the action distribution
Expand Down
3 changes: 2 additions & 1 deletion qdax/baselines/diayn.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(self, config: DiaynConfig, action_size: int):
self._policy, self._critic, self._discriminator = make_diayn_networks(
num_skills=self._config.num_skills,
action_size=action_size,
hidden_layer_sizes=self._config.hidden_layer_sizes,
policy_hidden_layer_size=self._config.policy_hidden_layer_size,
critic_hidden_layer_size=self._config.critic_hidden_layer_size,
)

# define the action distribution
Expand Down
7 changes: 5 additions & 2 deletions qdax/baselines/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class SacConfig:
alpha_init: float = 1.0
discount: float = 0.97
reward_scaling: float = 1.0
hidden_layer_sizes: tuple = (256, 256)
critic_hidden_layer_size: tuple = (256, 256)
policy_hidden_layer_size: tuple = (256, 256)
fix_alpha: bool = False


Expand All @@ -82,7 +83,9 @@ def __init__(self, config: SacConfig, action_size: int) -> None:

# define the networks
self._policy, self._critic = make_sac_networks(
action_size=action_size, hidden_layer_sizes=self._config.hidden_layer_sizes
action_size=action_size,
critic_hidden_layer_size=self._config.critic_hidden_layer_size,
policy_hidden_layer_size=self._config.policy_hidden_layer_size,
)

# define the action distribution
Expand Down
6 changes: 4 additions & 2 deletions qdax/baselines/sac_pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class PBTSacConfig:
tau: float = 0.005
normalize_observations: bool = False
alpha_init: float = 1.0
hidden_layer_sizes: tuple = (256, 256)
policy_hidden_layer_size: tuple = (256, 256)
critic_hidden_layer_size: tuple = (256, 256)
fix_alpha: bool = False


Expand All @@ -123,7 +124,8 @@ def __init__(self, config: PBTSacConfig, action_size: int) -> None:
tau=config.tau,
normalize_observations=config.normalize_observations,
alpha_init=config.alpha_init,
hidden_layer_sizes=config.hidden_layer_sizes,
policy_hidden_layer_size=config.policy_hidden_layer_size,
critic_hidden_layer_size=config.critic_hidden_layer_size,
fix_alpha=config.fix_alpha,
# unused default values for parameters that will be learnt as part of PBT
learning_rate=3e-4,
Expand Down
11 changes: 6 additions & 5 deletions qdax/core/neuroevolution/networks/dads_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def __call__(
def make_dads_networks(
action_size: int,
descriptor_size: int,
hidden_layer_sizes: Tuple[int, ...] = (256, 256),
critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
omit_input_dynamics_dim: int = 2,
identity_covariance: bool = True,
dynamics_initializer: Optional[Initializer] = None,
Expand Down Expand Up @@ -155,7 +156,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [2 * action_size],
list(policy_hidden_layer_size) + [2 * action_size],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -167,7 +168,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network1 = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [1],
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -176,7 +177,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network2 = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [1],
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -191,7 +192,7 @@ def _dynamics_fn(
obs: StateDescriptor, skill: Skill, target: StateDescriptor
) -> jnp.ndarray:
dynamics_network = DynamicsNetwork(
hidden_layer_sizes,
critic_hidden_layer_size,
descriptor_size,
omit_input_dynamics_dim=omit_input_dynamics_dim,
identity_covariance=identity_covariance,
Expand Down
11 changes: 6 additions & 5 deletions qdax/core/neuroevolution/networks/diayn_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
def make_diayn_networks(
action_size: int,
num_skills: int,
hidden_layer_sizes: Tuple[int, ...] = (256, 256),
critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]:
"""Creates networks used in DIAYN.
Expand All @@ -30,7 +31,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [2 * action_size],
list(policy_hidden_layer_size) + [2 * action_size],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -42,7 +43,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network1 = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [1],
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -51,7 +52,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network2 = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [1],
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand All @@ -66,7 +67,7 @@ def _discriminator_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes) + [num_skills],
list(critic_hidden_layer_size) + [num_skills],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
Expand Down
Loading

0 comments on commit 6f78a4a

Please sign in to comment.