-
Notifications
You must be signed in to change notification settings - Fork 5
/
rpm.py
83 lines (69 loc) · 2.58 KB
/
rpm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# from collections import deque
import numpy as np
import random
import pickle
# replay buffer per http://pemami4911.github.io/blog/2016/08/21/ddpg-rl.html
class rpm(object):
    """Thread-safe circular experience replay memory.

    Stores arbitrary tuples of experience fields (observations, actions,
    rewards, ...) and samples random mini-batches with the fields stacked
    into numpy arrays.
    """
    def __init__(self, buffer_size):
        # buffer_size: maximum number of experiences to retain.
        self.buffer_size = buffer_size
        self.buffer = []
        # Next slot to overwrite once the buffer is full (circular index).
        self.index = 0
        import threading
        self.lock = threading.Lock()
    def add(self, obj):
        """Append one experience, overwriting the oldest once full."""
        # 'with' guarantees the lock is released even if an exception is
        # raised (the original acquire/release pair could leave it held).
        with self.lock:
            if self.size() > self.buffer_size:
                # Buffer somehow grew past capacity (e.g. a load() of a
                # larger dump): drop the oldest entries to fit.
                print('buffer size larger than set value, trimming...')
                self.buffer = self.buffer[(self.size() - self.buffer_size):]
                # After trimming, buffer[0] is the oldest entry, so the
                # circular overwrite should resume there.
                self.index = 0
            elif self.size() == self.buffer_size:
                # Full: overwrite in circular (oldest-first) fashion.
                self.buffer[self.index] = obj
                self.index = (self.index + 1) % self.buffer_size
            else:
                self.buffer.append(obj)
    def size(self):
        """Number of experiences currently stored."""
        return len(self.buffer)
    def sample_batch(self, batch_size):
        '''
        batch_size specifies the number of experiences to add
        to the batch. If the replay buffer has less than batch_size
        elements, simply return all of the elements within the buffer.
        Generally, you'll want to wait until the buffer has at least
        batch_size elements before beginning to sample from it.

        Returns a list with one entry per experience field; each entry
        is a stacked numpy array (or, for tuple-valued fields, a list of
        stacked arrays). 1-D results are reshaped to column vectors.
        Returns [] when the buffer is empty.

        NOTE(review): sampling is not lock-protected; presumably callers
        do not add() concurrently with sampling — confirm.
        '''
        if self.size() == 0:
            # Empty buffer: nothing to sample (original crashed at batch[0]).
            return []
        batch = random.sample(self.buffer, min(self.size(), batch_size))
        item_count = len(batch[0])
        res = []
        for i in range(item_count):
            if isinstance(batch[0][i], tuple):
                # Tuple-valued field: stack each sub-element separately.
                # np.stack requires a sequence of arrays, not a generator.
                k = [np.stack([item[i][j] for item in batch], axis=0)
                     for j in range(len(batch[0][i]))]
            else:
                k = np.stack([item[i] for item in batch], axis=0)
                # Promote 1-D fields (e.g. scalar rewards) to (n, 1).
                # Done only here: in the tuple branch k is a list and has
                # no .shape (the original crashed on that path).
                if len(k.shape) == 1:
                    k.shape += (1,)
            res.append(k)
        return res
    def save(self, pathname):
        """Pickle [buffer, index] to pathname (closes the file on exit)."""
        with self.lock:
            with open(pathname, 'wb') as f:
                pickle.dump([self.buffer, self.index], f)
            print('memory dumped into', pathname)
    def load(self, pathname):
        """Restore [buffer, index] from a pickle written by save().

        SECURITY: pickle.load executes arbitrary code from the file;
        only load files you trust.
        """
        # Take the lock for symmetry with save()/add(): both fields are
        # replaced together so readers never see a half-updated state.
        with self.lock:
            with open(pathname, 'rb') as f:
                [self.buffer, self.index] = pickle.load(f)
            print('memory loaded from', pathname)