-
Notifications
You must be signed in to change notification settings - Fork 1
/
rl_agent.py
44 lines (29 loc) · 1011 Bytes
/
rl_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
""" RL Agent
"""
import numpy as np
from battleship import Board
class RLAgent:
    """Agent that plays a game by driving a gym-style environment with a model.

    The agent keeps a reference to the board being played, the model used to
    choose actions, the environment being stepped, and the most recent
    observation returned by that environment.
    """

    def __init__(self, board, model, gym_env):
        """Set up the agent and synchronise the environment with its board.

        Args:
            board: Board to play on. When ``None``, a fresh ``Board()`` is
                created and used instead.
            model: Object exposing ``predict(obs)`` that returns an
                ``(action, state)`` pair.
            gym_env: Environment exposing ``reset()``, ``step(action)`` and
                ``_overwrite_board(board)``.
        """
        self.board = Board() if board is None else board
        self.model = model
        self.env = gym_env
        self.obs = self.env.reset()
        # Hand the environment the board the agent actually holds. Passing the
        # raw ``board`` argument would give the environment ``None`` whenever
        # the default board was created above.
        self.env._overwrite_board(self.board)

    def play_until_completion(self, debug=False):
        """Play the game until the environment reports it is done.

        Args:
            debug: Currently unused; kept for interface compatibility.

        Returns:
            A 3-tuple ``(score, episode_reward, reward_list)`` where ``score``
            is ``self.board.score()`` (torpedo count), ``episode_reward`` is
            the sum of the rewards collected during the episode, and
            ``reward_list`` is a one-element list holding that same total.
        """
        reward_list = []
        episode_reward = 0
        while True:
            action, _states = self.model.predict(self.obs)
            self.obs, reward, done, info = self.env.step(action)
            episode_reward += reward
            if done:
                # Record the total, but do not zero the accumulator: it is
                # part of the return value and must reflect the episode.
                reward_list.append(episode_reward)
                break
        return self.board.score(), episode_reward, reward_list