Skip to content

Commit

Permalink
[RLlib] Fix accumulation of results in (new API stack) Algorithm. (#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Nov 11, 2024
1 parent 21308bc commit 3141dfe
Show file tree
Hide file tree
Showing 26 changed files with 1,403 additions and 1,342 deletions.
7 changes: 4 additions & 3 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ py_test(
srcs = ["core/rl_module/tests/test_rl_module_specs.py"]
)

# Learner
# LearnerGroup
py_test(
name = "TestLearnerGroupSyncUpdate",
main = "core/learner/tests/test_learner_group.py",
Expand Down Expand Up @@ -1473,16 +1473,17 @@ py_test(
args = ["TestLearnerGroupSaveLoadState"]
)

# Learner
py_test(
name = "test_learner",
tags = ["team:rllib", "core", "ray_data"],
tags = ["team:rllib", "core", "ray_data", "exclusive"],
size = "medium",
srcs = ["core/learner/tests/test_learner.py"]
)

py_test(
name = "test_torch_learner_compile",
tags = ["team:rllib", "core", "ray_data"],
tags = ["team:rllib", "core", "ray_data", "exclusive"],
size = "medium",
srcs = ["core/learner/torch/tests/test_torch_learner_compile.py"]
)
Expand Down
1,262 changes: 671 additions & 591 deletions rllib/algorithms/algorithm.py

Large diffs are not rendered by default.

92 changes: 40 additions & 52 deletions rllib/algorithms/appo/appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.metrics import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
ResultDict,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -291,60 +288,51 @@ def __init__(self, config, *args, **kwargs):
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

@override(IMPALA)
def training_step(self) -> ResultDict:
train_results = super().training_step()
def training_step(self) -> None:
if self.config.enable_rl_module_and_learner:
return super().training_step()

train_results = super().training_step()
# Update the target network and the KL coefficient for the APPO-loss.
# The target network update frequency is calculated automatically by the product
# of `num_epochs` setting (usually 1 for APPO) and `minibatch_buffer_size`.
if self.config.enable_rl_module_and_learner:
if NUM_TARGET_UPDATES in train_results:
self._counters[NUM_TARGET_UPDATES] += train_results[NUM_TARGET_UPDATES]
self._counters[LAST_TARGET_UPDATE_TS] = train_results[
LAST_TARGET_UPDATE_TS
]
else:
last_update = self._counters[LAST_TARGET_UPDATE_TS]
cur_ts = self._counters[
(
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
)
]
target_update_freq = (
self.config.num_epochs * self.config.minibatch_buffer_size
last_update = self._counters[LAST_TARGET_UPDATE_TS]
cur_ts = self._counters[
(
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
)
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning(
"No data for {}, not updating kl".format(pi_id)
)

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.env_runner.foreach_policy_to_train(update)
]
target_update_freq = self.config.num_epochs * self.config.minibatch_buffer_size
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning("No data for {}, not updating kl".format(pi_id))

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.env_runner.foreach_policy_to_train(update)

return train_results

Expand Down
2 changes: 0 additions & 2 deletions rllib/algorithms/appo/tests/test_appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,11 @@ def test_appo_two_optimizers_two_lrs(self):
algo.stop()

def test_appo_entropy_coeff_schedule(self):
# Initial lr, doesn't really matter because of the schedule below.
config = (
appo.APPOConfig()
.environment("CartPole-v1")
.env_runners(
num_env_runners=1,
batch_mode="truncate_episodes",
rollout_fragment_length=10,
)
.training(
Expand Down
9 changes: 2 additions & 7 deletions rllib/algorithms/bc/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
from ray.rllib.utils.typing import RLModuleSpecType


class BCConfig(MARWILConfig):
Expand Down Expand Up @@ -113,15 +113,10 @@ def validate(self) -> None:
class BC(MARWIL):
"""Behavioral Cloning (derived from MARWIL).
Simply uses MARWIL with beta force-set to 0.0.
Uses MARWIL with beta force-set to 0.0.
"""

@classmethod
@override(MARWIL)
def get_default_config(cls) -> AlgorithmConfig:
return BCConfig()

@override(MARWIL)
def training_step(self) -> ResultDict:
# Call MARWIL's training step.
return super().training_step()
33 changes: 5 additions & 28 deletions rllib/algorithms/cql/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
Expand All @@ -39,9 +39,6 @@
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
OFFLINE_SAMPLING_TIMER,
TARGET_NET_UPDATE_TIMER,
Expand Down Expand Up @@ -301,14 +298,11 @@ def get_default_policy_class(
return CQLTFPolicy

@override(SAC)
def training_step(self) -> ResultDict:
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack()
else:
def training_step(self) -> None:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self) -> ResultDict:

# Sampling from offline data.
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Return an iterator in case we are using remote learners.
Expand All @@ -330,22 +324,6 @@ def _training_step_new_api_stack(self) -> ResultDict:

# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
Expand All @@ -364,8 +342,7 @@ def _training_step_new_api_stack(self) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

@OldAPIStack
def _training_step_old_api_stack(self) -> ResultDict:
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
Expand Down
92 changes: 22 additions & 70 deletions rllib/algorithms/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_EPISODES,
NUM_EPISODES_LIFETIME,
NUM_MODULE_STEPS_SAMPLED,
NUM_MODULE_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
REPLAY_BUFFER_ADD_DATA_TIMER,
REPLAY_BUFFER_SAMPLE_TIMER,
Expand Down Expand Up @@ -640,7 +632,7 @@ def get_default_policy_class(
return DQNTFPolicy

@override(Algorithm)
def training_step(self) -> ResultDict:
def training_step(self) -> None:
"""DQN training iteration function.
Each training iteration, we:
Expand All @@ -655,14 +647,14 @@ def training_step(self) -> ResultDict:
Returns:
The results dict from executing the training iteration.
"""
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack(with_noise_reset=True)
# Old API stack (Policy, RolloutWorker).
else:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
return self._training_step_new_api_stack(with_noise_reset=True)

def _training_step_new_api_stack(self, *, with_noise_reset):
# Alternate between storing and sampling and training.
store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

Expand All @@ -686,38 +678,16 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
self.local_replay_buffer.add(episodes)

self.metrics.log_dict(
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
),
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)
self.metrics.log_value(
NUM_ENV_STEPS_SAMPLED_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
reduce="sum",
)
self.metrics.log_value(
NUM_EPISODES_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
reduce="sum",
)
self.metrics.log_dict(
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
default={},
),
key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)

if self.config.count_steps_by == "agent_steps":
current_ts = sum(
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME).values()
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
).values()
)
else:
current_ts = self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
current_ts = self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
)

# If enough experiences have been sampled start training.
if current_ts >= self.config.num_steps_sampled_before_learning_starts:
Expand Down Expand Up @@ -748,10 +718,17 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
episodes=episodes,
timesteps={
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
)
),
NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(
ENV_RUNNER_RESULTS,
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
)
)
),
},
)
Expand All @@ -773,29 +750,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
self.metrics.merge_and_log_n_dicts(
learner_results, key=LEARNER_RESULTS
)
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
if NUM_MODULE_STEPS_TRAINED in stats
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
# (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
# ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")

# Update replay buffer priorities.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
Expand All @@ -816,8 +770,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
"""Training step for the old API stack.
Expand Down
Loading

0 comments on commit 3141dfe

Please sign in to comment.