-
Notifications
You must be signed in to change notification settings - Fork 0
/
driver.py
41 lines (31 loc) · 1014 Bytes
/
driver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from unityagents import UnityEnvironment
from trainer import Trainer
from tracker import PerformanceTracker
from agent import AgentFactory
from tracker import TrackerFactory
import time
def main(env_file='Reacher_Linux/Reacher.x86_64'):
    """Train an agent on the Reacher environment, replay it, then plot results.

    Steps:
      1. Open the Unity environment at ``env_file``.
      2. Print an environment description.
      3. Train for up to 200 episodes with a goal score of 30.0.
      4. Run the trained agent once with ``trainer.play``.
      5. Close the environment.
      6. Plot the training performance recorded by the returned tracker.

    The ``time.sleep`` pauses give the user time to read console output and
    watch the Unity window between phases.

    Args:
        env_file: Path to the Unity environment executable. Defaults to the
            Linux Reacher build previously hard-coded here.
    """
    print("Training the agent ...")
    env = UnityEnvironment(file_name=env_file)
    # Always release the Unity environment, even if training or playback
    # raises; otherwise the environment is never closed on failure.
    try:
        agent_factory = AgentFactory()
        tracker_factory = TrackerFactory()
        trainer = Trainer(env, agent_factory, tracker_factory)
        trainer.describe_environment()
        time.sleep(5)
        agent, tracker = trainer.train(
            n_episodes=200,
            plot_every=1000,
            learn_every=20,
            iterations_per_learn=10,
            goal_score=30.0,
        )
        print("Training complete!")
        time.sleep(5)
        print("Running the trained agent ...")
        trainer.play(agent)
        time.sleep(10)
    finally:
        env.close()
    print("Game finished!")
    print("Training performance")
    # Plotting only needs the tracker, so it happens after the env is closed.
    tracker.plot_performance()
    time.sleep(5)
if __name__ == '__main__':
    # Set the DRIVER_PROFILE environment variable (to any non-empty value) to
    # run main() under cProfile and print the 20 most expensive calls by
    # cumulative time. Without it, main() runs normally.
    import os

    if os.environ.get('DRIVER_PROFILE'):
        import cProfile
        import pstats

        profiler = cProfile.Profile()
        profiler.enable()
        try:
            main()
        finally:
            # Report even if main() fails part-way through.
            profiler.disable()
            pstats.Stats(profiler).sort_stats('cumulative').print_stats(20)
    else:
        main()