diff --git a/algorithms/bc/sac_agent.py b/algorithms/bc/sac_agent.py
index 3608389e..876c12c5 100644
--- a/algorithms/bc/sac_agent.py
+++ b/algorithms/bc/sac_agent.py
@@ -116,8 +116,9 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
 
     def update_model(self) -> Tuple[torch.Tensor, ...]:
         """Train the model after each episode."""
-        experiences = self.memory.sample()
-        demos = self.demo_memory.sample()
+        self.update_step += 1
+
+        experiences, demos = self.memory.sample(), self.demo_memory.sample()
 
         states, actions, rewards, next_states, dones = experiences
         demo_states, demo_actions, _, _, _ = demos
@@ -170,7 +171,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
         vf_loss.backward()
         self.vf_optimizer.step()
 
-        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+        if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
             # bc loss
             qf_mask = torch.gt(
                 self.qf_1(demo_states, demo_actions),
@@ -224,7 +225,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
         )
 
     def write_log(
-        self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+        self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
     ):
         """Write log about loss and score"""
         total_loss = loss.sum()
@@ -239,7 +240,7 @@ def write_log(
                 self.total_step,
                 score,
                 total_loss,
-                loss[0] * delayed_update,  # actor loss
+                loss[0] * policy_update_freq,  # actor loss
                 loss[1],  # qf_1 loss
                 loss[2],  # qf_2 loss
                 loss[3],  # vf loss
@@ -253,7 +254,7 @@ def write_log(
                 {
                     "score": score,
                     "total loss": total_loss,
-                    "actor loss": loss[0] * delayed_update,
+                    "actor loss": loss[0] * policy_update_freq,
                     "qf_1 loss": loss[1],
                     "qf_2 loss": loss[2],
                     "vf loss": loss[3],
diff --git a/algorithms/fd/sac_agent.py b/algorithms/fd/sac_agent.py
index 30810daa..4bf98273 100644
--- a/algorithms/fd/sac_agent.py
+++ b/algorithms/fd/sac_agent.py
@@ -79,6 +79,8 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
     # pylint: disable=too-many-statements
     def update_model(self) -> Tuple[torch.Tensor, ...]:
         """Train the model after each episode."""
+        self.update_step += 1
+
         experiences = self.memory.sample(self.beta)
         states, actions, rewards, next_states, dones, weights, indices, eps_d = (
             experiences
@@ -149,7 +151,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
         vf_loss.backward()
         self.vf_optimizer.step()
 
-        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+        if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
             # actor loss
             advantage = q_pred - v_pred.detach()
             actor_loss_element_wise = alpha * log_prob - advantage
@@ -212,6 +214,9 @@ def pretrain(self):
                 avg_loss = np.vstack(pretrain_loss).mean(axis=0)
                 pretrain_loss.clear()
                 self.write_log(
-                    0, avg_loss, 0, delayed_update=self.hyper_params["DELAYED_UPDATE"]
+                    0,
+                    avg_loss,
+                    0,
+                    policy_update_freq=self.hyper_params["POLICY_UPDATE_FREQ"],
                 )
         print("[INFO] Pre-Train Complete!\n")
diff --git a/algorithms/sac/agent.py b/algorithms/sac/agent.py
index 995cbb6e..e64dfcc3 100644
--- a/algorithms/sac/agent.py
+++ b/algorithms/sac/agent.py
@@ -47,6 +47,7 @@ class Agent(AbstractAgent):
         hyper_params (dict): hyper-parameters
         total_step (int): total step numbers
         episode_step (int): step number of the current episode
+        update_step (int): step number of updates
         i_episode (int): current episode number
 
     """
@@ -80,6 +81,7 @@ def __init__(
         self.curr_state = np.zeros((1,))
         self.total_step = 0
         self.episode_step = 0
+        self.update_step = 0
         self.i_episode = 0
 
         # automatic entropy tuning
@@ -153,6 +155,8 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
 
     def update_model(self) -> Tuple[torch.Tensor, ...]:
         """Train the model after each episode."""
+        self.update_step += 1
+
         experiences = self.memory.sample()
         states, actions, rewards, next_states, dones = experiences
         new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)
@@ -203,7 +207,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
         vf_loss.backward()
         self.vf_optimizer.step()
 
-        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+        if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
             # actor loss
             advantage = q_pred - v_pred.detach()
             actor_loss = (alpha * log_prob - advantage).mean()
@@ -280,7 +284,7 @@ def save_params(self, n_episode: int):
         AbstractAgent.save_params(self, params, n_episode)
 
     def write_log(
-        self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+        self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
     ):
         """Write log about loss and score"""
         total_loss = loss.sum()
@@ -295,7 +299,7 @@ def write_log(
                 self.total_step,
                 score,
                 total_loss,
-                loss[0] * delayed_update,  # actor loss
+                loss[0] * policy_update_freq,  # actor loss
                 loss[1],  # qf_1 loss
                 loss[2],  # qf_2 loss
                 loss[3],  # vf loss
@@ -308,7 +312,7 @@ def write_log(
                 {
                     "score": score,
                     "total loss": total_loss,
-                    "actor loss": loss[0] * delayed_update,
+                    "actor loss": loss[0] * policy_update_freq,
                     "qf_1 loss": loss[1],
                     "qf_2 loss": loss[2],
                     "vf loss": loss[3],
@@ -359,7 +363,10 @@ def train(self):
             if loss_episode:
                 avg_loss = np.vstack(loss_episode).mean(axis=0)
                 self.write_log(
-                    self.i_episode, avg_loss, score, self.hyper_params["DELAYED_UPDATE"]
+                    self.i_episode,
+                    avg_loss,
+                    score,
+                    self.hyper_params["POLICY_UPDATE_FREQ"],
                 )
 
             if self.i_episode % self.args.save_period == 0:
diff --git a/algorithms/td3/agent.py b/algorithms/td3/agent.py
index ec380e51..1239f601 100644
--- a/algorithms/td3/agent.py
+++ b/algorithms/td3/agent.py
@@ -222,7 +222,7 @@ def save_params(self, n_episode: int):
         AbstractAgent.save_params(self, params, n_episode)
 
     def write_log(
-        self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+        self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
     ):
         """Write log about loss and score"""
         total_loss = loss.sum()
@@ -235,7 +235,7 @@ def write_log(
                 self.episode_steps,
                 self.total_steps,
                 total_loss,
-                loss[0] * delayed_update,  # actor loss
+                loss[0] * policy_update_freq,  # actor loss
                 loss[1],  # critic1 loss
                 loss[2],  # critic2 loss
             )
@@ -246,7 +246,7 @@ def write_log(
                 {
                     "score": score,
                     "total loss": total_loss,
-                    "actor loss": loss[0] * delayed_update,
+                    "actor loss": loss[0] * policy_update_freq,
                     "critic1 loss": loss[1],
                     "critic2 loss": loss[2],
                 }
diff --git a/examples/lunarlander_continuous_v2/bc-sac.py b/examples/lunarlander_continuous_v2/bc-sac.py
index a2fbac7b..89b938ab 100644
--- a/examples/lunarlander_continuous_v2/bc-sac.py
+++ b/examples/lunarlander_continuous_v2/bc-sac.py
@@ -31,7 +31,7 @@
     "LR_QF1": 3e-4,
     "LR_QF2": 3e-4,
     "LR_ENTROPY": 3e-4,
-    "DELAYED_UPDATE": 2,
+    "POLICY_UPDATE_FREQ": 2,
     "BUFFER_SIZE": int(1e6),
     "BATCH_SIZE": 512,
     "DEMO_BATCH_SIZE": 64,
diff --git a/examples/lunarlander_continuous_v2/sac.py b/examples/lunarlander_continuous_v2/sac.py
index ae5a4cfa..1ffd501a 100644
--- a/examples/lunarlander_continuous_v2/sac.py
+++ b/examples/lunarlander_continuous_v2/sac.py
@@ -22,20 +22,20 @@
     "GAMMA": 0.99,
     "TAU": 5e-3,
     "W_ENTROPY": 1e-3,
-    "W_MEAN_REG": 1e-3,
-    "W_STD_REG": 1e-3,
+    "W_MEAN_REG": 0.0,
+    "W_STD_REG": 0.0,
     "W_PRE_ACTIVATION_REG": 0.0,
     "LR_ACTOR": 3e-4,
     "LR_VF": 3e-4,
     "LR_QF1": 3e-4,
     "LR_QF2": 3e-4,
     "LR_ENTROPY": 3e-4,
-    "DELAYED_UPDATE": 2,
+    "POLICY_UPDATE_FREQ": 2,
     "BUFFER_SIZE": int(1e6),
-    "BATCH_SIZE": 512,
+    "BATCH_SIZE": 128,
     "AUTO_ENTROPY_TUNING": True,
     "WEIGHT_DECAY": 0.0,
-    "INITIAL_RANDOM_ACTION": 5000,
+    "INITIAL_RANDOM_ACTION": int(1e4),
     "MULTIPLE_LEARN": 1,
 }
 
diff --git a/examples/lunarlander_continuous_v2/sacfd.py b/examples/lunarlander_continuous_v2/sacfd.py
index 82fdbe4a..2c41469e 100644
--- a/examples/lunarlander_continuous_v2/sacfd.py
+++ b/examples/lunarlander_continuous_v2/sacfd.py
@@ -34,7 +34,7 @@
     "W_MEAN_REG": 1e-3,
     "W_STD_REG": 1e-3,
     "W_PRE_ACTIVATION_REG": 0.0,
-    "DELAYED_UPDATE": 2,
+    "POLICY_UPDATE_FREQ": 2,
     "PRETRAIN_STEP": 100,
     "MULTIPLE_LEARN": 2,  # multiple learning updates
     "LAMBDA1": 1.0,  # N-step return weight
diff --git a/examples/reacher_v2/bc-sac.py b/examples/reacher_v2/bc-sac.py
index 2d1558a6..95b78cb2 100644
--- a/examples/reacher_v2/bc-sac.py
+++ b/examples/reacher_v2/bc-sac.py
@@ -31,7 +31,7 @@
     "LR_QF1": 3e-4,
     "LR_QF2": 3e-4,
     "LR_ENTROPY": 3e-4,
-    "DELAYED_UPDATE": 2,
+    "POLICY_UPDATE_FREQ": 2,
     "BUFFER_SIZE": int(1e6),
     "BATCH_SIZE": 512,
     "DEMO_BATCH_SIZE": 64,
diff --git a/examples/reacher_v2/sac.py b/examples/reacher_v2/sac.py
index 1f37124f..8afdb305 100644
--- a/examples/reacher_v2/sac.py
+++ b/examples/reacher_v2/sac.py
@@ -30,7 +30,7 @@
     "LR_QF1": 3e-4,
     "LR_QF2": 3e-4,
     "LR_ENTROPY": 3e-4,
-    "DELAYED_UPDATE": 2,
+    "POLICY_UPDATE_FREQ": 2,
     "BUFFER_SIZE": int(1e6),
     "BATCH_SIZE": 512,
     "AUTO_ENTROPY_TUNING": True,