From 727bc4f4da0d8fd26bc0fc51cb6ac7a4d63957e9 Mon Sep 17 00:00:00 2001
From: kslazarev
Date: Mon, 21 Jan 2019 17:59:34 +0300
Subject: [PATCH] Reduce memory usage (2-3x)

Replace the state matrices (total_next_obs, next_states, total_state,
next_obs) with preallocated numpy arrays.
Remove unnecessary np.float32 conversions.
Replace array copies with in-place calculations.
---
 train.py | 57 ++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/train.py b/train.py
index 11720ba..9cd607f 100644
--- a/train.py
+++ b/train.py
@@ -112,7 +112,7 @@ def main():
         parent_conns.append(parent_conn)
         child_conns.append(child_conn)
 
-    states = np.zeros([num_worker, 4, 84, 84])
+    states = np.zeros([num_worker, 4, 84, 84], dtype='float32')
 
     sample_episode = 0
     sample_rall = 0
@@ -124,61 +124,65 @@
 
     # normalize obs
     print('Start to initailize observation normalization parameter.....')
-    next_obs = []
+    next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
     for step in range(num_step * pre_obs_norm_step):
         actions = np.random.randint(0, output_size, size=(num_worker,))
 
         for parent_conn, action in zip(parent_conns, actions):
             parent_conn.send(action)
 
-        for parent_conn in parent_conns:
+        for idx, parent_conn in enumerate(parent_conns):
             s, r, d, rd, lr = parent_conn.recv()
-            next_obs.append(s[3, :, :].reshape([1, 84, 84]))
+            next_obs[(step % num_step) * num_worker + idx, 0, :, :] = s[3, :, :]
 
-        if len(next_obs) % (num_step * num_worker) == 0:
-            next_obs = np.stack(next_obs)
+        if (step % num_step) == num_step - 1:
             obs_rms.update(next_obs)
-            next_obs = []
+            next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
     print('End to initalize...')
 
     while True:
-        total_state, total_reward, total_done, total_next_state, total_action, total_int_reward, total_next_obs, total_ext_values, total_int_values, total_policy, total_policy_np = \
-            [], [], [], [], [], [], [], [], [], [], []
+        total_state = np.zeros([num_worker * num_step, 4, 84, 84], dtype='float32')
+        total_next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
+        total_reward, total_done, total_next_state, total_action, total_int_reward, total_ext_values, total_int_values, total_policy, total_policy_np = \
+            [], [], [], [], [], [], [], [], []
         global_step += (num_worker * num_step)
         global_update += 1
 
         # Step 1. n-step rollout
-        for _ in range(num_step):
-            actions, value_ext, value_int, policy = agent.get_action(np.float32(states) / 255.)
+        for step in range(num_step):
+            actions, value_ext, value_int, policy = agent.get_action(states / 255.)
 
             for parent_conn, action in zip(parent_conns, actions):
                 parent_conn.send(action)
 
-            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
-            for parent_conn in parent_conns:
+            next_obs = np.zeros([num_worker, 1, 84, 84])
+            next_states = np.zeros([num_worker, 4, 84, 84])
+            rewards, dones, real_dones, log_rewards = [], [], [], []
+            for idx, parent_conn in enumerate(parent_conns):
                 s, r, d, rd, lr = parent_conn.recv()
-                next_states.append(s)
+                next_states[idx] = s
                 rewards.append(r)
                 dones.append(d)
                 real_dones.append(rd)
                 log_rewards.append(lr)
-                next_obs.append(s[3, :, :].reshape([1, 84, 84]))
+                next_obs[idx, 0] = s[3, :, :]
+                total_next_obs[idx * num_step + step, 0] = s[3, :, :]
 
-            next_states = np.stack(next_states)
             rewards = np.hstack(rewards)
             dones = np.hstack(dones)
             real_dones = np.hstack(real_dones)
-            next_obs = np.stack(next_obs)
 
             # total reward = int reward + ext Reward
-            intrinsic_reward = agent.compute_intrinsic_reward(
-                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
+            next_obs -= obs_rms.mean
+            next_obs /= np.sqrt(obs_rms.var)
+            next_obs.clip(-5, 5, out=next_obs)
+            intrinsic_reward = agent.compute_intrinsic_reward(next_obs)
             intrinsic_reward = np.hstack(intrinsic_reward)
             sample_i_rall += intrinsic_reward[sample_env_idx]
 
-            total_next_obs.append(next_obs)
+            for idx, state in enumerate(states):
+                total_state[idx * num_step + step] = state
             total_int_reward.append(intrinsic_reward)
-            total_state.append(states)
             total_reward.append(rewards)
             total_done.append(dones)
             total_action.append(actions)
@@ -207,11 +211,9 @@ def main():
         total_int_values.append(value_int)
         # --------------------------------------------------
 
-        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
         total_reward = np.stack(total_reward).transpose().clip(-1, 1)
         total_action = np.stack(total_action).transpose().reshape([-1])
         total_done = np.stack(total_done).transpose()
-        total_next_obs = np.stack(total_next_obs).transpose([1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
         total_ext_values = np.stack(total_ext_values).transpose()
         total_int_values = np.stack(total_int_values).transpose()
         total_logging_policy = np.vstack(total_policy_np)
@@ -260,8 +262,13 @@ def main():
         # -----------------------------------------------
 
         # Step 5. Training!
-        agent.train_model(np.float32(total_state) / 255., ext_target, int_target, total_action,
-                          total_adv, ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5),
+        total_state /= 255.
+        total_next_obs -= obs_rms.mean
+        total_next_obs /= np.sqrt(obs_rms.var)
+        total_next_obs.clip(-5, 5, out=total_next_obs)
+
+        agent.train_model(total_state, ext_target, int_target, total_action,
+                          total_adv, total_next_obs,
                           total_policy)
 
         if global_step % (num_worker * num_step * 100) == 0:
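
Below is a minimal, self-contained sketch (not part of the patch) of the preallocate-and-fill plus in-place normalization pattern the commit adopts. The sizes, the recv_frame() helper, and the batch-derived mean/var are hypothetical stand-ins for the real worker Pipes and the obs_rms running statistics in train.py.

import numpy as np

num_worker, num_step = 16, 128              # hypothetical rollout dimensions

def recv_frame(rng):
    # Stand-in for the single 84x84 frame a worker sends back over its Pipe.
    return rng.random((1, 84, 84), dtype=np.float32)

rng = np.random.default_rng(0)

# Pattern the patch removes: append per-step arrays, then np.stack().
# The Python list of arrays and the stacked copy briefly coexist in memory.
obs_list = [recv_frame(rng) for _ in range(num_worker * num_step)]
obs_stacked = np.stack(obs_list)

# Pattern the patch introduces: preallocate once, write rows in place
# (worker-major layout, matching the patch's idx * num_step + step indexing).
obs_prealloc = np.zeros([num_worker * num_step, 1, 84, 84], dtype=np.float32)
for step in range(num_step):
    for idx in range(num_worker):
        obs_prealloc[idx * num_step + step] = recv_frame(rng)

# In-place normalization, avoiding the temporaries created by
# ((obs - mean) / std).clip(...); train.py takes mean/var from obs_rms instead.
mean, var = obs_prealloc.mean(), obs_prealloc.var()
obs_prealloc -= mean
obs_prealloc /= np.sqrt(var)
obs_prealloc.clip(-5, 5, out=obs_prealloc)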