-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdqn_agent2.py
28 lines (23 loc) · 1021 Bytes
/
dqn_agent2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
# Seed TensorFlow's random number generation so that results are reproducible.
# This is a convenience only; the specific seed value is not important.
tf.set_random_seed(2212)
class DQNAgent:
    """Agent wrapping a small fully connected Keras network.

    The network maps an observation to ``action_dim`` linear outputs
    (presumably one Q-value per action, going by the class name) and is
    compiled with mean-squared-error loss and the Adam optimizer.
    """

    def __init__(self, sess, action_dim, observation_dim):
        """Bind Keras to *sess* and build the model.

        Args:
            sess: TensorFlow session that Keras should use.
            action_dim: Number of units in the output layer.
            observation_dim: Size of the observation vector, given either
                as a single int or as a shape tuple/list of ints.
        """
        # Force keras to use the session that we have created
        K.set_session(sess)
        self.sess = sess
        self.action_dim = action_dim
        self.observation_dim = observation_dim
        self.model = self.create_model()

    @staticmethod
    def _input_shape(observation_dim):
        """Return *observation_dim* as a shape tuple.

        ``(n)`` is just ``n`` in Python, not a one-element tuple, so a bare
        int must be wrapped explicitly before it is handed to ``Input``.
        Iterables (tuples, lists) are converted element-wise and therefore
        keep their previous meaning.
        """
        try:
            return tuple(observation_dim)
        except TypeError:
            # Not iterable: treat it as the length of a flat observation.
            return (observation_dim,)

    def create_model(self):
        """Build and compile the network.

        Architecture: input -> Dense(400, relu) -> Dense(300, relu)
        -> Dense(action_dim, linear).

        Returns:
            The compiled Keras ``Model`` (loss ``'mse'``, ``Adam(0.005)``).
        """
        state_input = Input(shape=self._input_shape(self.observation_dim))
        state_h1 = Dense(400, activation='relu')(state_input)
        state_h2 = Dense(300, activation='relu')(state_h1)
        output = Dense(self.action_dim, activation='linear')(state_h2)
        model = Model(inputs=state_input, outputs=output)
        model.compile(loss='mse', optimizer=Adam(0.005))
        return model