-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualization.py
115 lines (84 loc) · 3.24 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from brain import *
from map import *
from utils import *
import pygame
import time
import itertools
import matplotlib.pyplot as plt
import sys
# Only pygame's font module is initialised at import time, so the shared font
# below can be created before Visualization.run() calls pygame.init().
pygame.font.init()
# Module-level font used by Visualization.run() to render the points value.
myfont = pygame.font.SysFont("Comic Sans MS", 30)
class HumanInput(Brain):
    """Brain that takes its moves from the keyboard via the pygame event queue.

    Both the arrow keys and W/A/S/D are recognised.
    """

    def __init__(self):
        # No instance state is set up, and Brain.__init__ is not called.
        pass

    def predict_move(self, map: Map, *args, **kwargs):
        """Drain pending pygame events and return the last requested direction.

        Args:
            map: Accepted for interface compatibility; not read here.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The ``Direction`` bound to the most recent recognised key press
            among the pending events, or ``None`` when there was none.

        Raises:
            SystemExit: With status 1 when a ``pygame.QUIT`` event is seen.
        """
        # NOTE(review): K_DOWN / "s" map to Direction.UP and K_UP / "w" map to
        # Direction.DOWN. Presumably the map's y axis is flipped relative to
        # the screen -- confirm against the Direction definition.
        bindings = {
            pygame.K_LEFT: Direction.LEFT,
            ord("a"): Direction.LEFT,
            pygame.K_DOWN: Direction.UP,
            ord("s"): Direction.UP,
            pygame.K_RIGHT: Direction.RIGHT,
            ord("d"): Direction.RIGHT,
            pygame.K_UP: Direction.DOWN,
            ord("w"): Direction.DOWN,
        }

        chosen = None
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit(1)
            # Later key presses overwrite earlier ones within the same call.
            if event.type == pygame.KEYDOWN and event.key in bindings:
                chosen = bindings[event.key]
        return chosen
class Visualization:
    """Play a game on a ``Map`` in a pygame window, with moves from a ``Brain``.

    Args:
        map: Game state. ``run`` reads ``w``, ``h``, ``walls``, ``apple`` and
            ``snake`` from it and calls ``move()`` and ``restart()``.
        brain: Object whose ``predict_move(map)`` returns the next direction,
            or a falsy value to keep the current one.
        save_frames: When true, every rendered frame is written to
            ``tmp/frame<N>.png`` with ``N`` counting up from 1000.
        cell_size: Edge length, in pixels, of one grid cell.
        step_delay: Seconds to sleep between game steps.
    """

    def __init__(
        self,
        map: Map,
        brain: Brain,
        save_frames=False,
        *,
        cell_size=20,
        step_delay=1 / 5,
    ):
        self.map = map
        self.brain = brain
        self.save_frames = save_frames
        self.cell_size = cell_size
        self.step_delay = step_delay

    def _fill_cell(self, screen, cell, color):
        """Paint the grid cell at ``(cell[0], cell[1])`` with ``color``."""
        size = self.cell_size
        pygame.draw.rect(screen, color, (cell[0] * size, cell[1] * size, size, size))

    def _draw(self, screen, points):
        """Redraw walls, apple, snake and the points text onto ``screen``."""
        screen.fill((255, 255, 255))
        for cell in self.map.walls:
            self._fill_cell(screen, cell, (0, 0, 0))
        self._fill_cell(screen, self.map.apple, (200, 0, 0))
        for cell in self.map.snake:
            self._fill_cell(screen, cell, (0, 200, 0))
        # The first snake segment (presumably the head) is repainted darker.
        self._fill_cell(screen, self.map.snake[0], (0, 150, 0))
        # NOTE(review): the text is white, the same colour as the background
        # fill, so it only shows where it overlaps darker cells -- confirm
        # this is intended.
        textsurface = myfont.render(str(points), False, (255, 255, 255))
        screen.blit(textsurface, (5, 0))

    def run(self):
        """Open the window and step the game until it ends.

        With a ``HumanInput`` brain the map is restarted after every game
        over and the loop never returns by itself. With any other brain the
        final points value is printed and the method returns.
        """
        # Initialise pygame before creating the display surface.
        pygame.init()
        size = self.cell_size
        screen = pygame.display.set_mode([self.map.w * size, self.map.h * size])

        if self.save_frames:
            # pygame.image.save does not create missing directories.
            import os

            os.makedirs("tmp", exist_ok=True)

        direction = Direction.RIGHT
        frames = itertools.count(1000)
        while True:
            direction = self.brain.predict_move(self.map) or direction
            # Brains that never read the event queue would otherwise leave it
            # unserviced and the window could be reported as unresponsive.
            pygame.event.pump()
            points, end = self.map.move(direction.value)

            self._draw(screen, points)
            if self.save_frames:
                pygame.image.save(screen, f"tmp/frame{next(frames)}.png")
            pygame.display.flip()

            if end:
                # Leave the final frame visible for a moment.
                time.sleep(1)
                if isinstance(self.brain, HumanInput):
                    self.map.restart()
                else:
                    print(points)
                    break
            time.sleep(self.step_delay)
class Plots:
    """Matplotlib charts over a recorded fitness history.

    ``history`` is stored as given. Both charts label the x axis "Epoch", so
    each entry is presumably one epoch's fitness values -- verify against the
    caller.
    """

    def __init__(self, history: list[np.ndarray]):
        self.history = history

    def _label_and_show(self, title):
        """Set ``title`` and the shared axis labels, then show the figure."""
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel("Fitness Value (snake's length)")
        plt.show()

    def plot_best_per_epoch(self):
        """Line plot of the maximum along axis 1 of ``history``."""
        best = np.max(self.history, axis=1)
        plt.plot(best)
        self._label_and_show("Best Fitness Value per Epoch")

    def plot_pop_quality_per_epoch(self):
        """Box plot of ``history``."""
        plt.boxplot(self.history)
        self._label_and_show("Population Fitness Value per Epoch")
if __name__ == "__main__":
    # Script entry point: play a keyboard-controlled game on a default Map,
    # without saving frames.
    Visualization(Map(), HumanInput(), save_frames=False).run()