-
Notifications
You must be signed in to change notification settings - Fork 0
/
CubeEnv.py
88 lines (71 loc) · 2.43 KB
/
CubeEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
'''
Wraps the RubiksCube class from the Cube module in an OpenAI Gym environment.
'''
from gym import Env,Space
from Cube import RubiksCube
import numpy as np
import matplotlib.animation as animation
class CubeEnv(Env):
    '''
    OpenAI Gym environment wrapping a RubiksCube.

    Observations are the cube's internal ``_cube`` array. Actions select one
    entry of ``CubeActionSpace.moves``.
    '''

    # The reward is the mean, over rows of score_similarity(), of the change
    # in each row's maximum, so it can be negative. Assuming scores lie in
    # [0, 1] (as the ``== 1`` completion test suggests -- TODO confirm), the
    # reward lies in [-1, 1], not [0, 1].
    reward_range = (-1, 1)

    def __init__(self, n=3):
        '''
        Initialise a Cube environment.

        :param n: size argument forwarded to ``RubiksCube(n)``.
        '''
        self.cube = RubiksCube(n)
        self.action_space = CubeActionSpace(self.cube)
        # NOTE(review): this is the raw cube array, not a gym Space
        # (CubeObservationSpace is defined but unused). Kept as-is so callers
        # relying on it keep working.
        self.observation_space = self.cube._cube
        self.moves = self.action_space.moves
        self.metadata = {"render.modes": ["human", "rgb_array"]}
        # Baseline score that the next step's reward is measured against.
        self.score = self.cube.score_similarity()
        self.num_moves = 0

    def _step(self, action):
        '''
        Apply a single move to the cube.

        :param action: either an integer index into ``self.moves`` (what
            ``action_space.sample()`` returns), or an array of per-move
            scores, in which case the arg-max entry is applied.
        :returns: ``(observation, reward, done, num_moves)``. The fourth
            element is the running move count, not a gym ``info`` dict.
        '''
        action = np.asarray(action)
        # argmax of a scalar is always 0, so integer actions must be used
        # directly or every sampled action would apply the first move.
        if action.ndim == 0:
            index = int(action)
        else:
            index = int(np.argmax(action))
        self.cube.rotate_cube(*self.moves[index])
        observation = self.cube._cube
        score = self.cube.score_similarity()
        # Per-row change in the best similarity, averaged into one reward.
        reward = np.max(score, axis=1) - np.max(self.score, axis=1)
        self.score = score
        done = np.mean(score) == 1
        self.num_moves += 1
        return observation, np.mean(reward), done, self.num_moves

    def _reset(self):
        '''
        Reset the cube via ``RubiksCube.reset(100)`` and restart the episode.

        The argument 100 presumably means a 100-move scramble (TODO confirm).

        :returns: whatever ``RubiksCube.reset`` returns.
        '''
        result = self.cube.reset(100)
        # Re-baseline the score so the first reward of the new episode is not
        # computed against the previous episode's final state, and restart
        # the move counter.
        self.score = self.cube.score_similarity()
        self.num_moves = 0
        return result

    def _render(self, mode="rgb_array", close=False):
        '''
        Return ``cube_colours()`` for ``"rgb_array"`` mode; otherwise None.
        '''
        if mode == "rgb_array":
            return self.cube.cube_colours()
        if close:
            return None

    def _close(self):
        '''Nothing to release; present to satisfy the gym Env interface.'''
        return None

    def _seed(self, seed=None):
        '''
        Seed NumPy's global RNG, which ``CubeActionSpace.sample`` draws from.

        Creating a fresh ``np.random.RandomState`` and discarding it would
        have no effect. ``seed=0`` is honoured (only ``None`` is skipped).

        :returns: ``[seed]``, following the gym seeding convention.
        '''
        if seed is not None:
            np.random.seed(seed)
        return [seed]

    def init_figure(self):
        '''Delegate to ``RubiksCube.show_layout``.'''
        self.cube.show_layout()
class CubeActionSpace(Space):
    '''
    Discrete set of cube rotations.

    Each move is an ``(i, j, k)`` triple with ``i`` and ``j`` drawn from
    ``range(cube.n)`` and ``k`` from ``(-1, 1)``. Actions are indices into
    ``moves``.
    '''
    def __init__(self, cube):
        self.cube = cube
        side = self.cube.n
        self.moves = np.array([
            (i, j, direction)
            for i in range(side)
            for j in range(side)
            for direction in [-1, 1]
        ])
        self.shape = self.moves.shape
        half = len(self.moves) / 2
        self.high = half
        self.low = -half

    def sample(self):
        '''Return a uniformly random move index in ``[0, len(moves))``.'''
        count = len(self.moves)
        return np.random.randint(count)

    def contains(self, x):
        '''Return True when ``x`` equals some valid move index.'''
        valid_indices = range(len(self.moves))
        return x in valid_indices
class CubeObservationSpace(Space):
    '''
    Placeholder observation space.

    It stores the cube's initial ``score_similarity()`` in ``obvs``. Its only
    valid (and sampled) value is 0.
    '''
    def __init__(self, cube):
        self.cube = cube
        initial_score = self.cube.score_similarity()
        self.obvs = [initial_score]

    def sample(self):
        '''Always return 0, the single member of this stub space.'''
        return 0

    def contains(self, x):
        '''Return the result of comparing ``x`` with 0.'''
        is_member = x == 0
        return is_member