-
Notifications
You must be signed in to change notification settings - Fork 22
/
database.py
68 lines (58 loc) · 2.03 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import gc
import time
class database:
    """Fixed-size circular replay memory of single 84x84 frames.

    Every slot stores one frame together with the action, reward and
    terminal flag passed to :meth:`insert`.  Training samples are built on
    demand by stacking four consecutive stored frames, so each frame is kept
    only once.

    Note: windows are assembled purely by slot position.  No check is made
    for terminal flags inside a window or for the current write position, so
    a sampled window can mix frames from different episodes.
    """

    # Height/width of one stored frame.
    FRAME_SHAPE = (84, 84)
    # Number of consecutive frames stacked into one sample.
    HISTORY = 4

    def __init__(self, size, input_dims):
        """Allocate zero-filled storage for ``size`` slots.

        Args:
            size: Number of slots in the circular buffer.
            input_dims: Accepted for interface compatibility only; it is not
                used, the frame shape is fixed by ``FRAME_SHAPE``.
        """
        self.size = size
        self.states = np.zeros((self.size,) + self.FRAME_SHAPE, dtype=float)
        self.actions = np.zeros(self.size, dtype=float)
        self.terminals = np.zeros(self.size, dtype=float)
        self.rewards = np.zeros(self.size, dtype=float)
        self.counter = 0  # slot that the next insert() writes to
        self.batch_counter = 0  # position of the next unused entry in rand_idxs
        # Start with an empty index pool so the first get_batches() call
        # builds (and shuffles) a pool from the slots that are really filled,
        # instead of assuming 300 entries already exist.
        self.rand_idxs = np.empty(0, dtype=int)
        self.flag = False  # becomes True once the buffer has wrapped around

    def get_four(self, idx):
        """Build the stacked current/next observations for slot ``idx``.

        Args:
            idx: Slot of the newest frame of the current observation.

        Returns:
            Tuple ``(four_s, action, terminal, four_n, reward)`` where
            ``four_s`` stacks slots ``idx-3 .. idx`` and ``four_n`` stacks
            slots ``idx-2 .. idx+1`` along the last axis, each with shape
            ``(84, 84, 4)``.

        Raises:
            IndexError: If ``idx + 1`` lies beyond the last slot.
        """
        # Integer-array indexing resolves each slot on its own, so negative
        # slots wrap to the end of the buffer exactly like scalar indexing,
        # and the result is a fresh copy rather than a view of the storage.
        window = idx + np.arange(1 - self.HISTORY, 1)
        four_s = np.ascontiguousarray(self.states[window].transpose(1, 2, 0))
        four_n = np.ascontiguousarray(self.states[window + 1].transpose(1, 2, 0))
        return four_s, self.actions[idx], self.terminals[idx], four_n, self.rewards[idx]

    def _refresh_indices(self):
        """Rebuild and shuffle the pool of sampleable slots.

        A slot is sampleable when three earlier frames and one later frame
        exist, i.e. slots ``3 .. get_size() - 2``.

        Raises:
            IndexError: If fewer than five frames are stored, in which case
                no slot is sampleable.
        """
        first = self.HISTORY - 1
        stop = self.get_size() - 1
        if stop <= first:
            raise IndexError(
                "replay memory holds %d frame(s); at least %d are needed to sample"
                % (self.get_size(), self.HISTORY + 1)
            )
        self.rand_idxs = np.arange(first, stop)
        np.random.shuffle(self.rand_idxs)
        self.batch_counter = 0

    def get_batches(self, bat_size):
        """Draw ``bat_size`` samples from the stored frames.

        Slots are consumed from a shuffled pool; the pool is rebuilt from the
        currently filled part of the buffer whenever fewer than ``bat_size``
        unused entries remain.

        Args:
            bat_size: Number of samples to return.

        Returns:
            Tuple ``(states, actions, terminals, next_states, rewards)`` with
            shapes ``(bat_size, 84, 84, 4)``, ``(bat_size,)``,
            ``(bat_size,)``, ``(bat_size, 84, 84, 4)`` and ``(bat_size,)``.

        Raises:
            IndexError: If fewer than five frames have been stored.
        """
        stacked = (bat_size,) + self.FRAME_SHAPE + (self.HISTORY,)
        bat_s = np.zeros(stacked)
        bat_a = np.zeros(bat_size)
        bat_t = np.zeros(bat_size)
        bat_n = np.zeros(stacked)
        bat_r = np.zeros(bat_size)
        for i in range(bat_size):
            if self.batch_counter >= len(self.rand_idxs) - bat_size:
                self._refresh_indices()
            idx = self.rand_idxs[self.batch_counter]
            bat_s[i], bat_a[i], bat_t[i], bat_n[i], bat_r[i] = self.get_four(idx)
            self.batch_counter += 1
        return bat_s, bat_a, bat_t, bat_n, bat_r

    def insert(self, prevstate_proc, reward, action, terminal):
        """Store one frame and its metadata in the next slot.

        Once the last slot has been written the write position returns to
        slot 0 and subsequent inserts overwrite the oldest entries.

        Args:
            prevstate_proc: Frame assignable to an ``(84, 84)`` float array.
            reward: Reward stored with the frame.
            action: Action stored with the frame (kept as a float).
            terminal: Terminal flag stored with the frame (kept as a float).
        """
        slot = self.counter
        self.states[slot] = prevstate_proc
        self.rewards[slot] = reward
        self.actions[slot] = action
        self.terminals[slot] = terminal
        self.counter += 1
        if self.counter >= self.size:
            # Buffer is full: remember that and wrap the write position.
            self.flag = True
            self.counter = 0

    def get_size(self):
        """Return the number of slots that currently hold inserted data."""
        if self.flag:
            return self.size
        return self.counter