-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
105 lines (87 loc) · 3.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from env.wlanenvironment import wlanEnv
from Brain import BrainDQN
import time
import numpy as np
from display import Display
# import signal
CONTROLLER_IP = '10.103.43.130:8080'
BUFFER_LEN = 60
ENV_REFRESH_INTERVAL = 0.1
# def sigint_handler(signum, frame):
#     """
#     Catch CTRL+C Event
#     :param signum:
#     :param frame:
#     :return:
#     """
#     global is_sigint_up
#     is_sigint_up = True
#     print 'Caught interrupt signal.'
#
# signal.signal(signal.SIGINT, sigint_handler)
# is_sigint_up = False
def train():
    """Train the DQN agent online against the live WLAN environment.

    Connects to the SDN controller, waits for the first valid observation,
    then loops forever: query the brain for an action, apply it via
    ``env.step``, feed the sample to the live display, and store the
    transition with ``setPerception``. Runs until interrupted with Ctrl+C,
    at which point the replay memory is saved and the display is stopped.
    """
    env = wlanEnv(CONTROLLER_IP, BUFFER_LEN, timeInterval=ENV_REFRESH_INTERVAL)
    env.start()
    numAPs, numActions, numAdditionDim = env.getDimSpace()
    brain = BrainDQN(numActions, numAPs, numAdditionDim, BUFFER_LEN,
                     param_file='saved_networks/network-dqn.params')
    # Block until the environment reports its first valid observation
    # (env.observe() returns a (ready_flag, observation) pair).
    while not env.observe()[0]:
        time.sleep(0.5)
    observation0 = env.observe()[1]
    brain.setInitState(observation0)
    np.set_printoptions(threshold=5)
    print('Initial observation:\n' + str(observation0))
    fig = Display(env.id2ap)
    fig.display()
    try:
        while True:
            action, q = brain.getAction()
            print('action:\n' + str(action.argmax()))
            reward, throught, nextObservation = env.step(action)
            print('reward: ' + str(reward) + ', throught: ' + str(throught))
            print('Next observation:\n' + str(nextObservation))
            # Build a FRESH dict each step: Display.append may keep a
            # reference, and reusing one mutable dict would make every
            # stored sample alias the latest values.
            data = {
                'timestamp': time.time(),
                'rssi': nextObservation[-1],
                'q': q,
                'reward': reward,
                'action_index': np.argmax(action),
            }
            fig.append(data)
            # Terminal flag is always False: this is a continuing task.
            brain.setPerception(nextObservation, action, reward, False)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop training: persist and clean up.
        print('Saving replayMemory......')
        brain.saveReplayMemory()
        fig.stop()
def test():
    """Evaluate the trained DQN agent against the live WLAN environment.

    Like ``train`` but inference-only: loads the saved network parameters,
    predicts an action for each observation, applies it, and plots the
    resulting q-value/reward stream. No learning updates and no replay
    memory writes. Runs until interrupted with Ctrl+C.
    """
    env = wlanEnv(CONTROLLER_IP, BUFFER_LEN, timeInterval=ENV_REFRESH_INTERVAL,
                  no_guarantee=True)
    env.start()
    numAPs, numActions, numAdditionDim = env.getDimSpace()
    brain = BrainDQN(numActions, numAPs, numAdditionDim, BUFFER_LEN,
                     param_file='saved_networks/network-dqn.params')
    # Block until the environment reports its first valid observation
    # (env.observe() returns a (ready_flag, observation) pair).
    while not env.observe()[0]:
        time.sleep(0.5)
    observation = env.observe()[1]
    np.set_printoptions(threshold=5)
    fig = Display(env.id2ap, PREDICT=True)
    fig.display()
    try:
        while True:
            action, q_value, action_index, feature_vector = brain.predict(observation)
            print('action:\n' + str(action_index))
            reward, throught, observation = env.step(action)
            print('q_value: ' + str(q_value))
            print('reward: ' + str(reward) + ', throught: ' + str(throught))
            # Build a FRESH dict each step: Display.append may keep a
            # reference, and reusing one mutable dict would make every
            # stored sample alias the latest values.
            data = {
                'timestamp': time.time(),
                'rssi': observation[-1],
                'q': q_value,
                'reward': reward,
                'action_index': action_index,
                'feature_vector': feature_vector,
            }
            fig.append(data)
            print('Next observation:\n' + str(observation))
            # Slow the evaluation loop down so the plot is readable.
            time.sleep(2)
    except KeyboardInterrupt:
        fig.stop()
if __name__ == '__main__':
    # Entry point: run inference-only evaluation.
    # Switch to train() to run online training instead.
    test()