
Commit

Merge pull request #17 from pseudo-rnd-thoughts/np.bool
Change np.bool to bool
jjshoots authored Aug 31, 2022
2 parents d7b8bfb + cbf969b commit 5f682c2
Showing 5 changed files with 12 additions and 12 deletions.
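Background for the change (not part of the original commit message): np.bool was only ever an alias for Python's built-in bool; NumPy deprecated the alias in 1.20 and later removed it, so passing the built-in bool as a dtype is the drop-in replacement used throughout this diff. A minimal sketch of the equivalence, with the terminal_buf name mirroring the diff and the assert purely illustrative:

import numpy as np

# dtype=bool yields the same dtype that dtype=np.bool used to, because
# np.bool was an alias for the built-in bool before its removal from NumPy.
terminal_buf = np.empty(1, dtype=bool)
assert terminal_buf.dtype == np.bool_  # NumPy's boolean scalar type is unaffected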
2 changes: 1 addition & 1 deletion examples/models/mx_model/a2c.py
@@ -97,7 +97,7 @@ def __init__(self, env, handle, name, eval_obs=None,
self.feature_buf = np.empty((1,) + self.feature_space)
self.action_buf = np.empty(1, dtype=np.int32)
self.advantage_buf, self.value_buf = np.empty(1), np.empty(1)
-self.terminal_buf = np.empty(1, dtype=np.bool)
+self.terminal_buf = np.empty(1, dtype=bool)

# print("parameters", self.model.get_params())
# mx.viz.plot_network(self.output).view()
4 changes: 2 additions & 2 deletions examples/models/mx_model/dqn.py
@@ -126,7 +126,7 @@ def __init__(self, env, handle, name,
self.replay_buf_feature = ReplayBuffer(shape=(memory_size,) + self.feature_space)
self.replay_buf_action = ReplayBuffer(shape=(memory_size,), dtype=np.int32)
self.replay_buf_reward = ReplayBuffer(shape=(memory_size,))
-self.replay_buf_terminal = ReplayBuffer(shape=(memory_size,), dtype=np.bool)
+self.replay_buf_terminal = ReplayBuffer(shape=(memory_size,), dtype=bool)
self.replay_buf_mask = ReplayBuffer(shape=(memory_size,))
# if mask[i] == 0, then the item is used for padding, not for training

@@ -259,7 +259,7 @@ def _add_to_replay_buffer(self, sample_buffer):
m = len(r)

mask = np.ones((m,))
-terminal = np.zeros((m,), dtype=np.bool)
+terminal = np.zeros((m,), dtype=bool)
if episode.terminal:
terminal[-1] = True
else:
4 changes: 2 additions & 2 deletions examples/models/tf_model/dqn.py
@@ -144,7 +144,7 @@ def out_action(qvalues):
self.replay_buf_feature = ReplayBuffer(shape=(memory_size,) + self.feature_space)
self.replay_buf_action = ReplayBuffer(shape=(memory_size,), dtype=np.int32)
self.replay_buf_reward = ReplayBuffer(shape=(memory_size,))
-self.replay_buf_terminal = ReplayBuffer(shape=(memory_size,), dtype=np.bool)
+self.replay_buf_terminal = ReplayBuffer(shape=(memory_size,), dtype=bool)
self.replay_buf_mask = ReplayBuffer(shape=(memory_size,))
# if mask[i] == 0, then the item is used for padding, not for training

@@ -256,7 +256,7 @@ def _add_to_replay_buffer(self, sample_buffer):
m = len(r)

mask = np.ones((m,))
-terminal = np.zeros((m,), dtype=np.bool)
+terminal = np.zeros((m,), dtype=bool)
if episode.terminal:
terminal[-1] = True
else:
12 changes: 6 additions & 6 deletions examples/models/tf_model/drqn.py
@@ -135,7 +135,7 @@ def __init__(self, env, handle, name,
self.view_buf = np.empty((1,) + self.view_space)
self.feature_buf = np.empty((1,) + self.feature_space)
self.action_buf, self.reward_buf = np.empty(1, dtype=np.int32), np.empty(1)
-self.terminal_buf = np.empty(1, dtype=np.bool)
+self.terminal_buf = np.empty(1, dtype=bool)

def _create_network(self, input_view, input_feature, reuse=None):
"""define computation graph of network
@@ -285,7 +285,7 @@ def _add_to_replay_buffer(self, sample_buffer):
m = len(r)

mask = np.ones((m,))
-terminal = np.zeros((m,), dtype=np.bool)
+terminal = np.zeros((m,), dtype=bool)
if episode.terminal:
terminal[-1] = True
else:
@@ -335,7 +335,7 @@ def train(self, sample_buffer, print_every=500):
batch_feature = np.zeros((max_+1,) + self.feature_space, dtype=np.float32)
batch_action = np.zeros((max_,), dtype=np.int32)
batch_reward = np.zeros((max_,), dtype=np.float32)
-batch_terminal = np.zeros((max_,), dtype=np.bool)
+batch_terminal = np.zeros((max_,), dtype=bool)
batch_mask = np.zeros((max_,), dtype=np.float32)

# calc batch number
@@ -443,10 +443,10 @@ def train_keep_hidden(self, sample_buffer, print_every=500):
batch_feature = np.zeros((max_+1,) + self.feature_space, dtype=np.float32)
batch_action = np.zeros((max_,), dtype=np.int32)
batch_reward = np.zeros((max_,), dtype=np.float32)
-batch_terminal = np.zeros((max_,), dtype=np.bool)
+batch_terminal = np.zeros((max_,), dtype=bool)
batch_mask = np.zeros((max_,), dtype=np.float32)
batch_hidden = np.zeros((batch_size, self.state_size), dtype=np.float32)
-batch_pick = np.zeros((batch_size, max_len), dtype=np.bool)
+batch_pick = np.zeros((batch_size, max_len), dtype=bool)
pick_buffer = np.arange(max_, dtype=np.int32)

print("batches: %d add: %d replay_len: %d, %d/%d" %
@@ -486,7 +486,7 @@ def train_keep_hidden(self, sample_buffer, print_every=500):
rows.append(row)
n_rows = len(rows)

-batch_reset = np.zeros((train_length, batch_size), dtype=np.bool)
+batch_reset = np.zeros((train_length, batch_size), dtype=bool)
batch_mask[:] = 0

# copy from replay buffer to batch buffer
2 changes: 1 addition & 1 deletion magent/gridworld.py
@@ -353,7 +353,7 @@ def get_alive(self, handle):
whether the agents are alive
"""
n = self.get_num(handle)
-buf = np.empty((n,), dtype=np.bool)
+buf = np.empty((n,), dtype=bool)
_LIB.env_get_info(self.game, handle, b"alive",
buf.ctypes.data_as(ctypes.POINTER(ctypes.c_bool)))
return buf
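A note on why the replacement is also safe for the ctypes call above (an assumption drawn from general NumPy/ctypes behaviour, not stated in the diff): an array with dtype=bool stores one byte per element, which matches ctypes.c_bool on typical platforms, so the pointer cast in get_alive is unchanged. A self-contained sketch of the same pattern:

import ctypes
import numpy as np

# dtype=bool (formerly dtype=np.bool) gives one byte per element, which is
# what the C side reads through a ctypes.c_bool pointer on typical platforms.
buf = np.empty((4,), dtype=bool)
assert buf.itemsize == ctypes.sizeof(ctypes.c_bool)
ptr = buf.ctypes.data_as(ctypes.POINTER(ctypes.c_bool))  # same cast as in gridworld.py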
