SAC_model.py
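"""PyTorch network definitions for Soft Actor-Critic (SAC): a state-value
network, a soft Q-network, and a Gaussian policy with tanh-squashed actions."""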
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal


# def hidden_init(layer):
#     fan_in = layer.weight.data.size()[0]
#     lim = 1. / np.sqrt(fan_in)
#     return (-lim, lim)
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        """Initialize parameters of the ValueNetwork model.
        Params
        ======
            state_dim (int): Dimension of each state
            hidden_dim (int): Number of nodes in each hidden layer
            init_w (float): Bound of the uniform distribution used to
                initialize the output layer's weights and bias
        """
        super(ValueNetwork, self).__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Build a value network that maps state -> value."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        return self.linear3(x)
class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        """Initialize parameters of the SoftQNetwork model.
        Params
        ======
            num_inputs (int): Dimension of each state
            num_actions (int): Dimension of each action
            hidden_size (int): Number of nodes in each hidden layer
            init_w (float): Bound of the uniform distribution used to
                initialize the output layer's weights and bias
        """
        super(SoftQNetwork, self).__init__()
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Build a soft Q-network that maps (state, action) pairs -> Q-values."""
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)
class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3,
                 log_std_min=-20, log_std_max=2):
        """Initialize parameters of the Gaussian policy model.
        Params
        ======
            num_inputs (int): Dimension of each state
            num_actions (int): Dimension of each action
            hidden_size (int): Number of nodes in each hidden layer
            init_w (float): Bound of the uniform distribution used to
                initialize the output layers' weights and biases
            log_std_min / log_std_max (float): Clamp range for the
                predicted log standard deviation
        """
        super(PolicyNetwork, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Map state -> (mean, clamped log standard deviation) of a Gaussian."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample tanh-squashed actions and their log-probabilities for a batch of states."""
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        # rsample() uses the reparameterization trick so gradients can flow
        # back through the sampled action (plain sample() would block them,
        # which SAC's policy update requires).
        z = normal.rsample()
        action = torch.tanh(z)
        # Correct the Gaussian log-density for the tanh change of variables;
        # epsilon guards against log(0) at the squashing boundaries.
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        return action, log_prob, z, mean, log_std

    def get_action(self, state, device):
        """Sample a single action for one (unbatched) state, as a NumPy array."""
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action[0]
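

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test of the
# three networks. The state/action dimensions and hidden size below are
# illustrative assumptions, not values taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    state_dim, action_dim, hidden_size = 3, 1, 256  # assumed dimensions
    device = torch.device("cpu")

    value_net = ValueNetwork(state_dim, hidden_size)
    soft_q_net = SoftQNetwork(state_dim, action_dim, hidden_size)
    policy_net = PolicyNetwork(state_dim, action_dim, hidden_size)

    states = torch.randn(4, state_dim)  # batch of 4 random states
    actions, log_probs, z, mean, log_std = policy_net.evaluate(states)

    print(value_net(states).shape)            # torch.Size([4, 1])
    print(soft_q_net(states, actions).shape)  # torch.Size([4, 1])
    print(log_probs.shape)                    # torch.Size([4, 1])
    print(policy_net.get_action(np.zeros(state_dim), device))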