
Commit

[RLlib] Remove vtrace_drop_last_ts option and add proper vf bootstrapping to IMPALA and APPO. (#36013)
sven1977 authored Jun 22, 2023
1 parent 66535d5 commit e14c9b1
Showing 20 changed files with 321 additions and 211 deletions.
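
Context for the diff below: the old vtrace_drop_last_ts option worked around the missing value-function bootstrap by simply dropping the last timestep of every rollout fragment from the v-trace loss. This commit removes that option and instead has the policies write a SampleBatch.VALUES_BOOTSTRAPPED column during postprocessing, whose last time-major entry is fed to v-trace as bootstrap_value. As a reminder of where that value enters the math, here is a minimal NumPy sketch of the v-trace target recursion (not RLlib's implementation; the pre-clipped rhos/cs inputs and array names are assumptions for illustration):

import numpy as np

def vtrace_targets(rewards, values, bootstrap_value, terminateds, rhos, cs, gamma=0.99):
    """Minimal v-trace target recursion for one fragment (1-D time-major arrays of length T).

    `rhos` / `cs` are already-clipped importance ratios; `bootstrap_value` is the
    value estimate of the state that follows the last timestep of the fragment.
    """
    discounts = (1.0 - terminateds.astype(np.float32)) * gamma
    # V(x_{t+1}) for every t: shift `values` left and append the bootstrap value.
    values_tp1 = np.concatenate([values[1:], [bootstrap_value]])
    deltas = rhos * (rewards + discounts * values_tp1 - values)

    # Backward scan: vs_t - V(x_t) = delta_t + discount_t * c_t * (vs_{t+1} - V(x_{t+1})).
    acc = 0.0
    vs_minus_v = np.zeros_like(values)
    for t in reversed(range(len(rewards))):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    return values + vs_minus_v  # v-trace value targets vs_t

With a real bootstrap_value for the state after the fragment, the final timestep gets a well-defined target, so there is no longer a reason to drop it (and no drop_last plumbing through make_time_major, the KL term, the entropy term, etc.).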
12 changes: 11 additions & 1 deletion rllib/BUILD
@@ -301,6 +301,16 @@ py_test(
args = ["--dir=tuned_examples/appo"]
)

py_test(
name = "learning_tests_stateless_cartpole_appo_vtrace",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/stateless-cartpole-appo-vtrace.py"],
args = ["--dir=tuned_examples/appo"]
)

# ARS
py_test(
name = "learning_tests_cartpole_ars",
@@ -3710,7 +3720,7 @@ py_test(
)

# Taking out this test for now: Mixed torch- and tf- policies within the same
# Algorothm never really worked.
# Algorithm never really worked.
# py_test(
# name = "examples/multi_agent_two_trainers_mixed_torch_tf",
# main = "examples/multi_agent_two_trainers.py",
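
The new BUILD target registers a regression test for APPO with v-trace on StatelessCartPole, i.e. a recurrent (LSTM) policy on a partially observable task, which is a setup the bootstrapping change directly affects. The tuned example file itself is not part of the shown diff; the following is only a hypothetical sketch of what such a config might contain (the StatelessCartPole import path and all hyperparameters are assumptions):

# Hypothetical sketch -- not the actual tuned_examples/appo/stateless-cartpole-appo-vtrace.py.
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole  # assumed import path

config = (
    APPOConfig()
    .environment(StatelessCartPole)
    .training(
        vtrace=True,  # use the v-trace surrogate loss
        lr=0.0005,
        model={
            "use_lstm": True,       # recurrent model: the env hides velocities, so memory is needed
            "lstm_cell_size": 64,
            "max_seq_len": 20,
        },
    )
    .rollouts(num_rollout_workers=1)
)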
74 changes: 29 additions & 45 deletions rllib/algorithms/appo/appo_tf_policy.py
@@ -19,6 +19,7 @@
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_bootstrap_value,
compute_gae_for_sample_batch,
Postprocessing,
)
@@ -144,7 +145,6 @@ def loss(
is_multidiscrete = False
output_hidden_shape = 1

# TODO: (sven) deprecate this when trajectory view API gets activated.
def make_time_major(*args, **kw):
return _make_time_major(
self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
@@ -159,12 +159,16 @@ def make_time_major(*args, **kw):
prev_action_dist = dist_class(behaviour_logits, self.model)
values = self.model.value_function()
values_time_major = make_time_major(values)
bootstrap_values_time_major = make_time_major(
train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
)
bootstrap_value = bootstrap_values_time_major[-1]

if self.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = tf.reshape(mask, [-1])
mask = make_time_major(mask, drop_last=self.config["vtrace"])
mask = make_time_major(mask)

def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, mask))
@@ -173,11 +177,7 @@ def reduce_mean_valid(t):
reduce_mean_valid = tf.reduce_mean

if self.config["vtrace"]:
drop_last = self.config["vtrace_drop_last_ts"]
logger.debug(
"Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})"
)
logger.debug("Using V-Trace surrogate loss (vtrace=True)")

# Prepare actions for loss.
loss_actions = (
@@ -188,9 +188,7 @@ def reduce_mean_valid(t):
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

# Prepare KL for Loss
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
)
mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist))

unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1
@@ -203,26 +201,20 @@ def reduce_mean_valid(t):
with tf.device("/cpu:0"):
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=make_time_major(
unpacked_behaviour_logits, drop_last=drop_last
unpacked_behaviour_logits
),
target_policy_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=drop_last
),
actions=tf.unstack(
make_time_major(loss_actions, drop_last=drop_last), axis=2
unpacked_old_policy_behaviour_logits
),
actions=tf.unstack(make_time_major(loss_actions), axis=2),
discounts=tf.cast(
~make_time_major(
tf.cast(dones, tf.bool), drop_last=drop_last
),
~make_time_major(tf.cast(dones, tf.bool)),
tf.float32,
)
* self.config["gamma"],
rewards=make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1]
if drop_last
else values_time_major,
bootstrap_value=values_time_major[-1],
rewards=make_time_major(rewards),
values=values_time_major,
bootstrap_value=bootstrap_value,
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
clip_rho_threshold=tf.cast(
@@ -233,14 +225,10 @@ def reduce_mean_valid(t):
),
)

actions_logp = make_time_major(
action_dist.logp(actions), drop_last=drop_last
)
prev_actions_logp = make_time_major(
prev_action_dist.logp(actions), drop_last=drop_last
)
actions_logp = make_time_major(action_dist.logp(actions))
prev_actions_logp = make_time_major(prev_action_dist.logp(actions))
old_policy_actions_logp = make_time_major(
old_policy_action_dist.logp(actions), drop_last=drop_last
old_policy_action_dist.logp(actions)
)

is_ratio = tf.clip_by_value(
@@ -267,17 +255,12 @@ def reduce_mean_valid(t):
mean_policy_loss = -reduce_mean_valid(surrogate_loss)

# The value function loss.
if drop_last:
delta = values_time_major[:-1] - vtrace_returns.vs
else:
delta = values_time_major - vtrace_returns.vs
value_targets = vtrace_returns.vs
delta = values_time_major - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

# The entropy loss.
actions_entropy = make_time_major(
action_dist.multi_entropy(), drop_last=True
)
actions_entropy = make_time_major(action_dist.multi_entropy())
mean_entropy = reduce_mean_valid(actions_entropy)

else:
@@ -353,7 +336,6 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
self,
train_batch.get(SampleBatch.SEQ_LENS),
self.model.value_function(),
drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
)

stats_dict = {
@@ -388,20 +370,22 @@ def postprocess_trajectory(
other_agent_batches: Optional[SampleBatch] = None,
episode: Optional["Episode"] = None,
):
# Call super's postprocess_trajectory first.
# sample_batch = super().postprocess_trajectory(
# sample_batch, other_agent_batches, episode
# )

if not self.config["vtrace"]:
sample_batch = compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)
else:
# Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need
# inside the loss for vtrace calculations.
sample_batch = compute_bootstrap_value(sample_batch, self)

return sample_batch

@override(base)
def extra_action_out_fn(self) -> Dict[str, TensorType]:
extra_action_fetches = super().extra_action_out_fn()
if not self.config["vtrace"]:
extra_action_fetches[SampleBatch.VF_PREDS] = self.model.value_function()
return extra_action_fetches

@override(base)
def get_batch_divisibility_req(self) -> int:
return self.config["rollout_fragment_length"]
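
The loss-side pattern above repeats in the torch policy below: the drop_last branches disappear, the full values_time_major tensor goes into v-trace, and bootstrap_value is taken from the last time-major row of the new VALUES_BOOTSTRAPPED column. A small NumPy sketch of that reshaping, with a made-up fragment layout (shapes and numbers are assumptions), is:

import numpy as np

# Made-up flat batch: B=2 rollout fragments of length T=4, stored fragment after fragment.
T, B = 4, 2
vf_preds = np.arange(T * B, dtype=np.float32)            # per-timestep V(x_t)
values_bootstrapped = np.zeros(T * B, dtype=np.float32)
values_bootstrapped[T - 1] = 5.0                         # bootstrap value for fragment 0
values_bootstrapped[2 * T - 1] = 9.0                     # bootstrap value for fragment 1

def make_time_major(flat):
    # [B*T] -> [T, B]: column b holds fragment b's timesteps in order.
    return flat.reshape(B, T).T

values_time_major = make_time_major(vf_preds)                # shape [T, B], used in full
bootstrap_value = make_time_major(values_bootstrapped)[-1]   # shape [B] -> array([5., 9.])

# v-trace now gets the complete [T, B] value tensor plus one bootstrap value per
# fragment, instead of the truncated values_time_major[:-1] used with drop_last.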
85 changes: 36 additions & 49 deletions rllib/algorithms/appo/appo_torch_policy.py
@@ -19,6 +19,7 @@
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_bootstrap_value,
compute_gae_for_sample_batch,
Postprocessing,
)
@@ -157,14 +158,16 @@ def _make_time_major(*args, **kwargs):
prev_action_dist = dist_class(behaviour_logits, model)
values = model.value_function()
values_time_major = _make_time_major(values)

drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]
bootstrap_values_time_major = _make_time_major(
train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
)
bootstrap_value = bootstrap_values_time_major[-1]

if self.is_recurrent():
max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = torch.reshape(mask, [-1])
mask = _make_time_major(mask, drop_last=drop_last)
mask = _make_time_major(mask)
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
@@ -174,9 +177,7 @@ def reduce_mean_valid(t):
reduce_mean_valid = torch.mean

if self.config["vtrace"]:
logger.debug(
"Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
)
logger.debug("Using V-Trace surrogate loss (vtrace=True)")

old_policy_behaviour_logits = target_model_out.detach()
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
@@ -202,40 +203,30 @@ def reduce_mean_valid(t):
)

# Prepare KL for loss.
action_kl = _make_time_major(
old_policy_action_dist.kl(action_dist), drop_last=drop_last
)
action_kl = _make_time_major(old_policy_action_dist.kl(action_dist))

# Compute vtrace on the CPU for better perf.
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=drop_last
),
behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits),
target_policy_logits=_make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=drop_last
unpacked_old_policy_behaviour_logits
),
actions=torch.unbind(
_make_time_major(loss_actions, drop_last=drop_last), dim=2
),
discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
actions=torch.unbind(_make_time_major(loss_actions), dim=2),
discounts=(1.0 - _make_time_major(dones).float())
* self.config["gamma"],
rewards=_make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1] if drop_last else values_time_major,
bootstrap_value=values_time_major[-1],
rewards=_make_time_major(rewards),
values=values_time_major,
bootstrap_value=bootstrap_value,
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
)

actions_logp = _make_time_major(
action_dist.logp(actions), drop_last=drop_last
)
prev_actions_logp = _make_time_major(
prev_action_dist.logp(actions), drop_last=drop_last
)
actions_logp = _make_time_major(action_dist.logp(actions))
prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
old_policy_actions_logp = _make_time_major(
old_policy_action_dist.logp(actions), drop_last=drop_last
old_policy_action_dist.logp(actions)
)
is_ratio = torch.clamp(
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
@@ -259,16 +250,11 @@

# The value function loss.
value_targets = vtrace_returns.vs.to(values_time_major.device)
if drop_last:
delta = values_time_major[:-1] - value_targets
else:
delta = values_time_major - value_targets
delta = values_time_major - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

# The entropy loss.
mean_entropy = reduce_mean_valid(
_make_time_major(action_dist.entropy(), drop_last=drop_last)
)
mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

else:
logger.debug("Using PPO surrogate loss (vtrace=False)")
@@ -323,9 +309,7 @@ def reduce_mean_valid(t):
model.tower_stats["value_targets"] = value_targets
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(value_targets, [-1]),
torch.reshape(
values_time_major[:-1] if drop_last else values_time_major, [-1]
),
torch.reshape(values_time_major, [-1]),
)

return total_loss
@@ -378,10 +362,7 @@ def extra_action_out(
model: TorchModelV2,
action_dist: TorchDistributionWrapper,
) -> Dict[str, TensorType]:
out = {}
if not self.config["vtrace"]:
out[SampleBatch.VF_PREDS] = model.value_function()
return out
return {SampleBatch.VF_PREDS: model.value_function()}

@override(TorchPolicyV2)
def postprocess_trajectory(
@@ -391,17 +372,23 @@ def postprocess_trajectory(
episode: Optional["Episode"] = None,
):
# Call super's postprocess_trajectory first.
sample_batch = super().postprocess_trajectory(
sample_batch, other_agent_batches, episode
)
if not self.config["vtrace"]:
# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak
# in torch (issue #6962).
with torch.no_grad():
# sample_batch = super().postprocess_trajectory(
# sample_batch, other_agent_batches, episode
# )

# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak
# in torch (issue #6962).
with torch.no_grad():
if not self.config["vtrace"]:
sample_batch = compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)
else:
# Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need
# inside the loss for vtrace calculations.
sample_batch = compute_bootstrap_value(sample_batch, self)

return sample_batch

@override(TorchPolicyV2)
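
On the postprocessing side, both policies now branch the same way: GAE when vtrace=False, otherwise compute_bootstrap_value(sample_batch, self) to attach the VALUES_BOOTSTRAPPED column. The sketch below only illustrates the idea behind such a postprocessor and is not RLlib's compute_bootstrap_value; the plain-string batch keys and the value_fn helper are assumptions.

import numpy as np

def add_bootstrap_value_column(sample_batch, value_fn):
    """Conceptual sketch of a bootstrap-value postprocessor (not RLlib's implementation).

    If the fragment ended because the episode terminated, the bootstrap value is 0;
    if it was merely truncated, it is the value estimate of the observation that
    follows the fragment.
    """
    if bool(sample_batch["terminateds"][-1]):
        bootstrap = 0.0
    else:
        bootstrap = float(value_fn(sample_batch["new_obs"][-1]))

    # Store it as a full column so it travels with the batch through shuffling and
    # minibatching; the loss only reads the last time-major entry per fragment.
    col = np.zeros_like(sample_batch["vf_preds"])
    col[-1] = bootstrap
    sample_batch["values_bootstrapped"] = col
    return sample_batch

Wrapping this work in torch.no_grad(), as the torch policy above does, keeps the extra value-function forward pass out of the autograd graph and avoids the memory leak referenced in the diff's comment.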
3 changes: 3 additions & 0 deletions rllib/algorithms/appo/tests/test_appo_learner.py
@@ -36,6 +36,9 @@
SampleBatch.VF_PREDS: np.array(
list(reversed(range(frag_length))), dtype=np.float32
),
SampleBatch.VALUES_BOOTSTRAPPED: np.array(
list(reversed(range(frag_length))), dtype=np.float32
),
SampleBatch.ACTION_LOGP: np.log(
np.random.uniform(low=0, high=1, size=(frag_length,))
).astype(np.float32),
13 changes: 9 additions & 4 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -61,19 +61,24 @@ def compute_loss_for_module(
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
rewards_time_major = make_time_major(
batch[SampleBatch.REWARDS],
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
values_time_major = make_time_major(
values,
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
bootstrap_value = values_time_major[-1]
rewards_time_major = make_time_major(
batch[SampleBatch.REWARDS],
bootstrap_values_time_major = make_time_major(
batch[SampleBatch.VALUES_BOOTSTRAPPED],
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
bootstrap_value = bootstrap_values_time_major[-1]

# the discount factor that is used should be gamma except for timesteps where
# The discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
discounts_time_major = (
1.0
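
The learner-side change mirrors the policies: previously bootstrap_value was values_time_major[-1], i.e. the value of the last observed state itself, even though that state's timestep stays in the loss; now it is the last time-major entry of the VALUES_BOOTSTRAPPED column, i.e. the value of the state after the fragment. A minimal NumPy sketch with made-up shapes (everything here is illustrative, not the learner code):

import numpy as np

gamma = 0.99
# Time-major [T=4, B=2] terminated flags; fragment 1 terminates at t=2.
terminateds_tm = np.array(
    [[0.0, 0.0],
     [0.0, 0.0],
     [0.0, 1.0],
     [0.0, 0.0]], dtype=np.float32)

# Discount is gamma everywhere except at a termination step, where it is 0, so no
# value beyond the episode end leaks into the v-trace targets.
discounts_time_major = (1.0 - terminateds_tm) * gamma

values_time_major = np.random.randn(4, 2).astype(np.float32)       # V(x_t), used in full
values_bootstrapped_tm = np.random.randn(4, 2).astype(np.float32)  # new column, time-major
bootstrap_value = values_bootstrapped_tm[-1]                       # V of the state after t=T-1, shape [B]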