-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path · train.py
104 lines (79 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import neat
import gym
import pickle
import multiprocessing as mp
import visualize
gym.logger.set_level(40)  # 40 == ERROR: hide gym's warnings, keep its errors

# Work inside the "checkpoints" folder next to this script so NEAT checkpoints
# and pickled genomes are stored there. The path is absolute (not relative to
# the launch directory) so this chdir is idempotent: worker processes started
# with the "spawn" method re-execute this module from the already-changed
# working directory, where a relative "./checkpoints" would not exist.
CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.chdir(CHECKPOINT_DIR)

# Learning Parameters
NUM_GENERATIONS = 1000
PARALLEL = 2  # Number of environments (worker processes) to run at once
ENV = "MsPacman-ram-v0"  # RAM variant: 128-value observation -> 128 network inputs
CONFIG_FILE = "../config"  # NEAT config, relative to the checkpoints folder
class Train:
    """Evolve NEAT feed-forward networks that play the gym environment ``ENV``.

    Genomes are scored in separate worker processes, ``parallel`` at a time.
    By default the run resumes from a saved neat-python checkpoint.
    """

    # Checkpoint file resumed by default, relative to the current working
    # directory. Pass ``checkpoint=None`` to start a fresh population instead.
    DEFAULT_CHECKPOINT = "neat-checkpoint-408"

    def __init__(self, generations, parallel=2, checkpoint=DEFAULT_CHECKPOINT):
        """
        Args:
            generations: number of generations handed to ``Population.run``.
            parallel: how many genomes (worker processes) to evaluate at once.
            checkpoint: checkpoint file to resume from, or ``None`` to build a
                new population from the config file.
        """
        self.generations = generations
        self.par = parallel
        self.checkpoint = checkpoint

    @staticmethod
    def _fitness_func(genome, config, o, key=None):
        """Play one episode of ``ENV`` with ``genome`` and report its fitness.

        Intended to run inside a worker process. The fitness (the total
        episode reward) is put on the queue ``o``:
          * as a bare number when ``key`` is None (original behaviour);
          * as a ``(key, fitness)`` pair otherwise, so the parent can match
            results to genomes no matter which worker finishes first.
        If the episode raises and ``key`` was given, ``(key, None)`` is put
        before re-raising so the parent is not left blocked on the queue.

        A genome scoring 500 or more is also pickled to ``finisher.pkl`` as a
        safety copy in case the run crashes later.
        """
        env = None
        try:
            env = gym.make(ENV)
            state = env.reset()
            net = neat.nn.FeedForwardNetwork.create(genome, config)
            done = False
            total_reward = 0
            while not done:
                # Feed the flattened observation through the network; the
                # output unit with the largest activation is the action taken.
                output = net.activate(state.flatten())
                action = output.index(max(output))
                state, reward, done, _info = env.step(action)
                total_reward += reward
                # env.render()  # Uncomment this to watch the game while training
            fitness = total_reward
        except KeyboardInterrupt:
            # Ctrl+C: stop this worker quietly; `finally` still closes the env.
            raise SystemExit
        except Exception:
            if key is not None:
                o.put((key, None))  # unblock the parent's output.get()
            raise
        finally:
            if env is not None:
                env.close()

        o.put(fitness if key is None else (key, fitness))
        if fitness >= 500:
            # Save a good model just in case of a crash.
            # NOTE(review): parallel workers may race on this one file name.
            with open("finisher.pkl", "wb") as fh:
                pickle.dump(genome, fh)

    def _eval_genomes(self, genomes, config):
        """Assign ``fitness`` to every genome, evaluating ``self.par`` at a time.

        Args:
            genomes: iterable of ``(genome_id, genome)`` pairs, as passed by
                ``neat.Population.run`` to its fitness function.
            config: the ``neat.Config`` used to build each network.

        Raises:
            RuntimeError: if a worker failed to evaluate its genome.
        """
        pairs = list(genomes)
        for start in range(0, len(pairs), self.par):
            batch = pairs[start:start + self.par]
            output = mp.Queue()
            processes = [
                mp.Process(target=self._fitness_func,
                           args=(genome, config, output, slot))
                for slot, (_, genome) in enumerate(batch)
            ]
            for proc in processes:
                proc.start()
            # Results arrive in completion order, so each worker tags its
            # fitness with its slot in the batch. Drain the queue BEFORE
            # joining: a process that has put items on a Queue may not exit
            # until they are consumed, so join-then-get can deadlock.
            results = dict(output.get() for _ in processes)
            for proc in processes:
                proc.join()
            for slot, (genome_id, genome) in enumerate(batch):
                fitness = results[slot]
                if fitness is None:
                    raise RuntimeError(
                        f"fitness evaluation failed for genome {genome_id}; "
                        "see the worker traceback above")
                genome.fitness = fitness

    def _run(self, config_file, generations, checkpoint=DEFAULT_CHECKPOINT):
        """Run NEAT for ``generations`` generations, then save and plot the winner.

        Args:
            config_file: path to the NEAT config file.
            generations: number of generations to run.
            checkpoint: checkpoint file to resume from; ``None`` starts a
                fresh population built from ``config_file``.

        The winning genome is pickled to ``winner.pkl`` in the working directory.
        """
        config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                             neat.DefaultSpeciesSet, neat.DefaultStagnation,
                             config_file)
        if checkpoint is None:
            p = neat.Population(config)
        else:
            # A restored population carries the config that was pickled into
            # the checkpoint; ``config`` above is then only used for drawing.
            p = neat.Checkpointer.restore_checkpoint(checkpoint)
        p.add_reporter(neat.StdOutReporter(True))
        stats = neat.StatisticsReporter()
        p.add_reporter(stats)
        p.add_reporter(neat.Checkpointer(5))  # checkpoint every 5 generations

        winner = p.run(self._eval_genomes, generations)
        with open('winner.pkl', 'wb') as fh:
            pickle.dump(winner, fh)

        visualize.draw_net(config, winner, True)
        visualize.plot_stats(stats, ylog=False, view=True)
        visualize.plot_species(stats, view=True)

    def main(self, config_file=CONFIG_FILE):
        """Resolve ``config_file`` and start training.

        A relative ``config_file`` is resolved against the current working
        directory, which is what the default ``"../config"`` is written
        against (the module chdirs into its checkpoints folder). The former
        ``os.path.join(os.path.dirname(__file__), ...)`` only worked when
        ``__file__`` had no directory part. Since Python 3.9 a script's
        ``__file__`` is absolute, which made ``"../config"`` point outside
        the project.
        """
        config_path = os.path.abspath(config_file)
        self._run(config_path, self.generations, self.checkpoint)
if __name__ == "__main__":
    # Pass PARALLEL explicitly; it was previously defined but never used, so
    # editing it had no effect on the number of worker processes.
    trainer = Train(NUM_GENERATIONS, PARALLEL)
    trainer.main()