forked from ghliu/pytorch-ddpg
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutil.py
34 lines (22 loc) · 830 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
# Tensor type used when converting arrays to tensors. Both branches are
# single precision (float32) so results are consistent across devices and
# match the default float32 parameters of nn.Module layers; a float64 input
# would trigger a dtype-mismatch error in those layers.
FLOAT = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
def to_numpy(var):
    """Convert tensor ``var`` to a NumPy array, detached from autograd.

    The tensor is moved to host memory first if needed. For a CPU tensor
    ``.cpu()`` returns the same tensor, so the array shares memory with it;
    for a CUDA tensor the array is a host copy.

    Args:
        var: A ``torch.Tensor`` (may require grad).

    Returns:
        numpy.ndarray with the same shape and dtype as ``var``.
    """
    # detach() is the supported replacement for ``.data``; .cpu() is a no-op
    # for CPU tensors, so no device check against a global flag is needed.
    return var.detach().cpu().numpy()
def to_tensor(ndarray, dtype=None):
    """Convert array-like ``ndarray`` to a tensor of type ``dtype``.

    Args:
        ndarray: A NumPy array, or anything ``torch.as_tensor`` accepts
            (nested lists, Python or NumPy scalars).
        dtype: Target passed to ``Tensor.type``; either a tensor type such
            as ``torch.cuda.FloatTensor`` or a ``torch.dtype``. Defaults to
            the module-level ``FLOAT``.

    Returns:
        torch.Tensor converted to ``dtype``. When ``ndarray`` is a CPU NumPy
        array that already has the target type, the tensor shares its memory.
    """
    # Resolve the default lazily so FLOAT is only looked up when needed.
    target = FLOAT if dtype is None else dtype
    # For ndarrays, as_tensor behaves like from_numpy (no copy), but it also
    # accepts lists and scalars, which from_numpy rejects with a TypeError.
    return torch.as_tensor(ndarray).type(target)
def soft_update(target, source, tau):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.

    Args:
        target: Module whose parameters are overwritten.
        source: Module whose parameters are only read.
        tau: Float weight given to ``source`` in the blend.

    Raises:
        ValueError: If the two modules expose different numbers of parameters.
    """
    target_params = list(target.parameters())
    source_params = list(source.parameters())
    # zip() would silently stop at the shorter list and leave the remaining
    # target parameters stale, so require matching parameter counts.
    if len(target_params) != len(source_params):
        raise ValueError(
            f"parameter count mismatch: target has {len(target_params)}, "
            f"source has {len(source_params)}"
        )
    # no_grad hides the in-place writes from autograd (the supported
    # replacement for writing through ``.data``). mul_/add_ compute the same
    # blend without allocating two temporaries per parameter.
    with torch.no_grad():
        for target_param, param in zip(target_params, source_params):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)
def hard_update(target, source):
    """Copy every parameter of ``source`` into ``target`` in place.

    Args:
        target: Module whose parameters are overwritten.
        source: Module whose parameters are only read.

    Raises:
        ValueError: If the two modules expose different numbers of parameters.
    """
    target_params = list(target.parameters())
    source_params = list(source.parameters())
    # zip() would silently stop at the shorter list and leave the remaining
    # target parameters stale, so require matching parameter counts.
    if len(target_params) != len(source_params):
        raise ValueError(
            f"parameter count mismatch: target has {len(target_params)}, "
            f"source has {len(source_params)}"
        )
    # no_grad hides the in-place copy from autograd (the supported
    # replacement for writing through ``.data``).
    with torch.no_grad():
        for target_param, param in zip(target_params, source_params):
            target_param.copy_(param)
def process_obs(observation):
    """Placeholder for observation preprocessing.

    Not implemented: ``observation`` is ignored and ``None`` is returned.
    """
    # NOTE(review): stub — any caller currently receives None.
    return None
def normal_action(action):
    """Placeholder for mapping an action into a normalized range.

    Not implemented: ``action`` is ignored and ``None`` is returned.
    """
    # NOTE(review): stub — any caller currently receives None.
    return None
def reverse_action(action):
    """Placeholder for the inverse of ``normal_action``.

    Not implemented: ``action`` is ignored and ``None`` is returned.
    """
    # NOTE(review): stub — any caller currently receives None.
    return None