models.py
import jax
import jax.numpy as jnp
import haiku as hk
# Actor network: maps a state to an action in [-1, 1].
class Actor(hk.Module):
    def __init__(self, action_dim):
        super().__init__()
        self.action_dim = action_dim

    def __call__(self, state):
        # Feedforward network for the actor; layer weights are initialized
        # with VarianceScaling using a uniform distribution.
        x = hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(state)
        x = jax.nn.relu(x)
        x = hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self.action_dim, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(x)
        return jax.nn.tanh(x)  # tanh squashes the output into [-1, 1]


# Critic network: maps a (state, action) pair to a scalar value estimate.
class Critic(hk.Module):
    def __call__(self, state, action):
        # Concatenate state and action as the input to the critic network.
        x = jnp.concatenate([state, action], axis=-1)
        # Feedforward network for the critic, mirroring the actor's layout.
        x = hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(x)
        x = jax.nn.relu(x)
        x = hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(x)
        x = jax.nn.relu(x)
        x = hk.Linear(1, w_init=hk.initializers.VarianceScaling(scale=1.0, distribution='uniform'))(x)
        return x  # scalar value estimate
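
# --- Minimal usage sketch (not part of the original models.py) ---
# Haiku modules must be wrapped in hk.transform before they can be initialized
# or applied as pure functions. The wrapper names (forward_actor,
# forward_critic), the RNG seed, and the example dimensions (8-dim states,
# 4-dim actions) below are hypothetical assumptions for illustration only.
def forward_actor(state):
    return Actor(action_dim=4)(state)

def forward_critic(state, action):
    return Critic()(state, action)

# without_apply_rng drops the RNG argument from apply, since these forward
# passes are deterministic.
actor = hk.without_apply_rng(hk.transform(forward_actor))
critic = hk.without_apply_rng(hk.transform(forward_critic))

rng = jax.random.PRNGKey(0)
dummy_state = jnp.zeros((1, 8))   # batch of one 8-dim state (example shape)
dummy_action = jnp.zeros((1, 4))  # batch of one 4-dim action (example shape)

actor_params = actor.init(rng, dummy_state)
critic_params = critic.init(rng, dummy_state, dummy_action)

action = actor.apply(actor_params, dummy_state)                    # in [-1, 1]
q_value = critic.apply(critic_params, dummy_state, dummy_action)   # shape (1, 1)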