Fix SAC performance degradation on LunarLander (#141)
* Fix sac performance degradation on lunarlander

* Fix typo

* Tune hyperparams

* Tune the hyperparameters
Curt-Park authored Apr 11, 2019
1 parent 791ebd7 commit ad9665f
Showing 9 changed files with 38 additions and 25 deletions.
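
Note on the change set below: the fix adds a dedicated update_step counter to the SAC agents and gates the delayed actor update on it, renaming the DELAYED_UPDATE hyperparameter to POLICY_UPDATE_FREQ. The actor is now updated once every POLICY_UPDATE_FREQ learning updates (gated on update_step) instead of being gated on the environment-step counter total_step. A minimal sketch of that gating pattern, with the names taken from the diff and the rest of the agent assumed/omitted:

class SACAgentSketch:
    """Simplified sketch of the delayed policy update introduced in this commit.

    update_step, total_step, and POLICY_UPDATE_FREQ come from the diff;
    everything else in the agent is omitted here.
    """

    def __init__(self, hyper_params: dict):
        self.hyper_params = hyper_params
        self.total_step = 0   # environment steps taken
        self.update_step = 0  # learning updates performed

    def update_model(self, experiences):
        self.update_step += 1

        # qf_1, qf_2 and vf losses are computed and optimized on every call
        # (omitted here).

        if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
            # The actor loss and soft target update run only on these calls,
            # i.e. once every POLICY_UPDATE_FREQ learning updates,
            # independent of self.total_step.
            pass
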
13 changes: 7 additions & 6 deletions algorithms/bc/sac_agent.py
@@ -116,8 +116,9 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):

def update_model(self) -> Tuple[torch.Tensor, ...]:
"""Train the model after each episode."""
-experiences = self.memory.sample()
-demos = self.demo_memory.sample()
+self.update_step += 1
+
+experiences, demos = self.memory.sample(), self.demo_memory.sample()

states, actions, rewards, next_states, dones = experiences
demo_states, demo_actions, _, _, _ = demos
@@ -170,7 +171,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
vf_loss.backward()
self.vf_optimizer.step()

-if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
# bc loss
qf_mask = torch.gt(
self.qf_1(demo_states, demo_actions),
@@ -224,7 +225,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
)

def write_log(
-self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
):
"""Write log about loss and score"""
total_loss = loss.sum()
@@ -239,7 +240,7 @@
self.total_step,
score,
total_loss,
-loss[0] * delayed_update, # actor loss
+loss[0] * policy_update_freq, # actor loss
loss[1], # qf_1 loss
loss[2], # qf_2 loss
loss[3], # vf loss
@@ -253,7 +254,7 @@
{
"score": score,
"total loss": total_loss,
"actor loss": loss[0] * delayed_update,
"actor loss": loss[0] * policy_update_freq,
"qf_1 loss": loss[1],
"qf_2 loss": loss[2],
"vf loss": loss[3],
9 changes: 7 additions & 2 deletions algorithms/fd/sac_agent.py
@@ -79,6 +79,8 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
# pylint: disable=too-many-statements
def update_model(self) -> Tuple[torch.Tensor, ...]:
"""Train the model after each episode."""
+self.update_step += 1
+
experiences = self.memory.sample(self.beta)
states, actions, rewards, next_states, dones, weights, indices, eps_d = (
experiences
@@ -149,7 +151,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
vf_loss.backward()
self.vf_optimizer.step()

-if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
# actor loss
advantage = q_pred - v_pred.detach()
actor_loss_element_wise = alpha * log_prob - advantage
@@ -212,6 +214,9 @@ def pretrain(self):
avg_loss = np.vstack(pretrain_loss).mean(axis=0)
pretrain_loss.clear()
self.write_log(
-0, avg_loss, 0, delayed_update=self.hyper_params["DELAYED_UPDATE"]
+0,
+avg_loss,
+0,
+policy_update_freq=self.hyper_params["POLICY_UPDATE_FREQ"],
)
print("[INFO] Pre-Train Complete!\n")
17 changes: 12 additions & 5 deletions algorithms/sac/agent.py
@@ -47,6 +47,7 @@ class Agent(AbstractAgent):
hyper_params (dict): hyper-parameters
total_step (int): total step numbers
episode_step (int): step number of the current episode
+update_step (int): step number of updates
i_episode (int): current episode number
"""
@@ -80,6 +81,7 @@ def __init__(
self.curr_state = np.zeros((1,))
self.total_step = 0
self.episode_step = 0
+self.update_step = 0
self.i_episode = 0

# automatic entropy tuning
@@ -153,6 +155,8 @@ def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):

def update_model(self) -> Tuple[torch.Tensor, ...]:
"""Train the model after each episode."""
+self.update_step += 1
+
experiences = self.memory.sample()
states, actions, rewards, next_states, dones = experiences
new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)
@@ -203,7 +207,7 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
vf_loss.backward()
self.vf_optimizer.step()

-if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
+if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
# actor loss
advantage = q_pred - v_pred.detach()
actor_loss = (alpha * log_prob - advantage).mean()
@@ -280,7 +284,7 @@ def save_params(self, n_episode: int):
AbstractAgent.save_params(self, params, n_episode)

def write_log(
-self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
):
"""Write log about loss and score"""
total_loss = loss.sum()
@@ -295,7 +299,7 @@ def write_log(
self.total_step,
score,
total_loss,
-loss[0] * delayed_update, # actor loss
+loss[0] * policy_update_freq, # actor loss
loss[1], # qf_1 loss
loss[2], # qf_2 loss
loss[3], # vf loss
@@ -308,7 +312,7 @@
{
"score": score,
"total loss": total_loss,
"actor loss": loss[0] * delayed_update,
"actor loss": loss[0] * policy_update_freq,
"qf_1 loss": loss[1],
"qf_2 loss": loss[2],
"vf loss": loss[3],
@@ -359,7 +363,10 @@ def train(self):
if loss_episode:
avg_loss = np.vstack(loss_episode).mean(axis=0)
self.write_log(
-self.i_episode, avg_loss, score, self.hyper_params["DELAYED_UPDATE"]
+self.i_episode,
+avg_loss,
+score,
+self.hyper_params["POLICY_UPDATE_FREQ"],
)

if self.i_episode % self.args.save_period == 0:
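
A detail worth noting from the write_log changes above: the actor loss is computed only once every policy_update_freq model updates, so the episode average of loss[0] underestimates it, and multiplying by policy_update_freq rescales it back to a per-actor-update figure. A small self-contained sketch of that bookkeeping, assuming (as the averaging suggests, though not shown in full in the diff) that 0.0 is recorded for the actor loss on the skipped updates:

import numpy as np

# Assumption: when the actor step is skipped, 0.0 is appended for the actor
# loss, so the plain episode mean of loss[0] is low by a factor of
# policy_update_freq; the multiplication restores an estimate per actual
# actor update.
policy_update_freq = 2
loss_episode = [
    np.array([0.0, 1.2, 1.1, 0.9]),  # actor update skipped on this call
    np.array([0.5, 1.0, 1.0, 0.8]),  # actor updated on this call
]
avg_loss = np.vstack(loss_episode).mean(axis=0)
actor_loss_logged = avg_loss[0] * policy_update_freq  # 0.25 * 2 == 0.5
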
6 changes: 3 additions & 3 deletions algorithms/td3/agent.py
@@ -222,7 +222,7 @@ def save_params(self, n_episode: int):
AbstractAgent.save_params(self, params, n_episode)

def write_log(
-self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
+self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
):
"""Write log about loss and score"""
total_loss = loss.sum()
@@ -235,7 +235,7 @@ def write_log(
self.episode_steps,
self.total_steps,
total_loss,
-loss[0] * delayed_update, # actor loss
+loss[0] * policy_update_freq, # actor loss
loss[1], # critic1 loss
loss[2], # critic2 loss
)
@@ -246,7 +246,7 @@
{
"score": score,
"total loss": total_loss,
"actor loss": loss[0] * delayed_update,
"actor loss": loss[0] * policy_update_freq,
"critic1 loss": loss[1],
"critic2 loss": loss[2],
}
2 changes: 1 addition & 1 deletion examples/lunarlander_continuous_v2/bc-sac.py
@@ -31,7 +31,7 @@
"LR_QF1": 3e-4,
"LR_QF2": 3e-4,
"LR_ENTROPY": 3e-4,
"DELAYED_UPDATE": 2,
"POLICY_UPDATE_FREQ": 2,
"BUFFER_SIZE": int(1e6),
"BATCH_SIZE": 512,
"DEMO_BATCH_SIZE": 64,
10 changes: 5 additions & 5 deletions examples/lunarlander_continuous_v2/sac.py
@@ -22,20 +22,20 @@
"GAMMA": 0.99,
"TAU": 5e-3,
"W_ENTROPY": 1e-3,
"W_MEAN_REG": 1e-3,
"W_STD_REG": 1e-3,
"W_MEAN_REG": 0.0,
"W_STD_REG": 0.0,
"W_PRE_ACTIVATION_REG": 0.0,
"LR_ACTOR": 3e-4,
"LR_VF": 3e-4,
"LR_QF1": 3e-4,
"LR_QF2": 3e-4,
"LR_ENTROPY": 3e-4,
"DELAYED_UPDATE": 2,
"POLICY_UPDATE_FREQ": 2,
"BUFFER_SIZE": int(1e6),
"BATCH_SIZE": 512,
"BATCH_SIZE": 128,
"AUTO_ENTROPY_TUNING": True,
"WEIGHT_DECAY": 0.0,
"INITIAL_RANDOM_ACTION": 5000,
"INITIAL_RANDOM_ACTION": int(1e4),
"MULTIPLE_LEARN": 1,
}

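
The tuned LunarLanderContinuous-v2 config above zeroes the mean/std regularization weights, drops BATCH_SIZE from 512 to 128, and doubles the random warm-up period. A minimal sketch of how an INITIAL_RANDOM_ACTION warm-up is commonly consumed (assumed pattern, not this repo's exact code; the `policy` callable is a hypothetical stand-in for the agent's action selection, and running it requires gym with the Box2D extra):

import gym

def exploration_action(env, total_step, hyper_params, policy=None):
    """Assumed warm-up pattern: act uniformly at random for the first
    INITIAL_RANDOM_ACTION environment steps, then defer to the SAC policy."""
    if total_step < hyper_params["INITIAL_RANDOM_ACTION"]:
        return env.action_space.sample()
    return policy(env)

env = gym.make("LunarLanderContinuous-v2")
hyper_params = {"INITIAL_RANDOM_ACTION": int(1e4)}
action = exploration_action(env, total_step=0, hyper_params=hyper_params)
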
2 changes: 1 addition & 1 deletion examples/lunarlander_continuous_v2/sacfd.py
@@ -34,7 +34,7 @@
"W_MEAN_REG": 1e-3,
"W_STD_REG": 1e-3,
"W_PRE_ACTIVATION_REG": 0.0,
"DELAYED_UPDATE": 2,
"POLICY_UPDATE_FREQ": 2,
"PRETRAIN_STEP": 100,
"MULTIPLE_LEARN": 2, # multiple learning updates
"LAMBDA1": 1.0, # N-step return weight
2 changes: 1 addition & 1 deletion examples/reacher_v2/bc-sac.py
@@ -31,7 +31,7 @@
"LR_QF1": 3e-4,
"LR_QF2": 3e-4,
"LR_ENTROPY": 3e-4,
"DELAYED_UPDATE": 2,
"POLICY_UPDATE_FREQ": 2,
"BUFFER_SIZE": int(1e6),
"BATCH_SIZE": 512,
"DEMO_BATCH_SIZE": 64,
2 changes: 1 addition & 1 deletion examples/reacher_v2/sac.py
@@ -30,7 +30,7 @@
"LR_QF1": 3e-4,
"LR_QF2": 3e-4,
"LR_ENTROPY": 3e-4,
"DELAYED_UPDATE": 2,
"POLICY_UPDATE_FREQ": 2,
"BUFFER_SIZE": int(1e6),
"BATCH_SIZE": 512,
"AUTO_ENTROPY_TUNING": True,
