[RLlib] Fix accumulation of results in (new API stack) Algorithm. #48136

Merged
Changes shown below are from 13 of the 94 commits.

Commits
490e254
wip
sven1977 Oct 17, 2024
0c8fb9e
wip
sven1977 Oct 18, 2024
8d46658
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 18, 2024
12185ca
wip
sven1977 Oct 18, 2024
6e4652b
ppo reporting everything ok now.
sven1977 Oct 18, 2024
ef549b8
fix episodes/episodes-lifetime in env runners.
sven1977 Oct 18, 2024
1171ccf
wip
sven1977 Oct 18, 2024
85c48e8
wip
sven1977 Oct 19, 2024
bd5a884
wip
sven1977 Oct 21, 2024
937ff49
wip
sven1977 Oct 21, 2024
4673c96
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 21, 2024
a686fc7
wip
sven1977 Oct 21, 2024
a6fcc37
wip
sven1977 Oct 21, 2024
b6ef29e
wip
sven1977 Oct 22, 2024
666ba01
wip
sven1977 Oct 22, 2024
70939e7
wip
sven1977 Oct 22, 2024
18fbb91
wip
sven1977 Oct 22, 2024
dbf2d07
fix
sven1977 Oct 22, 2024
75f761f
fix
sven1977 Oct 23, 2024
6c1aa7a
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 23, 2024
64c09e4
wip
sven1977 Oct 24, 2024
d0969d6
wip
sven1977 Oct 24, 2024
244cf40
wip
sven1977 Oct 24, 2024
056e043
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 24, 2024
cb76eaf
wip
sven1977 Oct 24, 2024
c9eabac
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 25, 2024
a06fbcc
wip
sven1977 Oct 25, 2024
c64932e
fixes
sven1977 Oct 25, 2024
1809ad3
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 28, 2024
e9ae0a9
wip
sven1977 Oct 28, 2024
61b8af0
wip
sven1977 Oct 28, 2024
54579d5
wip
sven1977 Oct 28, 2024
8443dcb
Revert "Revert "[RLlib] Upgrade to gymnasium 1.0.0 (ale_py 0.10.1, mu…
sven1977 Oct 29, 2024
ab2b22c
wip
sven1977 Oct 29, 2024
43bd52f
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 29, 2024
716c241
Merge branch 'revert-48297-revert-45328-upgrade_gymnasium_to_1_0_0a1'…
sven1977 Oct 29, 2024
a967fd4
wip
sven1977 Oct 29, 2024
bc17c93
wip
sven1977 Oct 29, 2024
17c6bad
wip
sven1977 Oct 30, 2024
ee208a0
wip
sven1977 Oct 30, 2024
bef9e1f
wip
sven1977 Oct 30, 2024
43b9ba6
wip
sven1977 Oct 30, 2024
317875a
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 31, 2024
b2aebd1
wip
sven1977 Oct 31, 2024
c403ffe
wip
sven1977 Oct 31, 2024
bde9583
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 31, 2024
e576ebe
wip
sven1977 Oct 31, 2024
7396518
wip
sven1977 Oct 31, 2024
3ff57ae
learns Pong-v5 on 1 (local) GPU and 46 env runners in ~6-7min.
sven1977 Oct 31, 2024
8afddb4
wip
sven1977 Nov 1, 2024
ced8703
fix
sven1977 Nov 1, 2024
8148259
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 1, 2024
a98568a
fix
sven1977 Nov 1, 2024
5e29b1f
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 1, 2024
dde1132
fix
sven1977 Nov 1, 2024
db4641c
fix
sven1977 Nov 1, 2024
aa9c578
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 1, 2024
20efe00
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 2, 2024
5b979f7
wip
sven1977 Nov 4, 2024
157060f
wip
sven1977 Nov 4, 2024
0c09e74
wip
sven1977 Nov 5, 2024
3602517
fixes
sven1977 Nov 5, 2024
97cb2a8
merge
sven1977 Nov 5, 2024
8574688
fix
sven1977 Nov 5, 2024
c674cd7
fix
sven1977 Nov 5, 2024
f3c0352
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 5, 2024
051c3bc
wip
sven1977 Nov 5, 2024
cebbec1
fix
sven1977 Nov 5, 2024
fa07017
fix
sven1977 Nov 6, 2024
0e34fd9
merge
sven1977 Nov 6, 2024
07faf22
fix
sven1977 Nov 6, 2024
e49f8d6
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 6, 2024
a1f68b1
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 6, 2024
0b59a96
Merge branch 'fix_impala_gpu_loader_thread_and_local_learner' into fi…
sven1977 Nov 6, 2024
8b75db9
merge
sven1977 Nov 6, 2024
fa63e33
fix
sven1977 Nov 6, 2024
277e057
wip
sven1977 Nov 6, 2024
8fae002
fix
sven1977 Nov 6, 2024
cfde0c4
Merge branch 'fix_impala_gpu_loader_thread_and_local_learner' into fi…
sven1977 Nov 6, 2024
308e161
fix
sven1977 Nov 6, 2024
ce26d74
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 7, 2024
037cd78
wip
sven1977 Nov 7, 2024
3f31afa
wip
sven1977 Nov 7, 2024
89048b6
Merge branch 'fix_impala_gpu_loader_thread_and_local_learner' into fi…
sven1977 Nov 7, 2024
d8eaf0e
wip
sven1977 Nov 7, 2024
c197c78
merge
sven1977 Nov 7, 2024
131d1db
wip
sven1977 Nov 7, 2024
fda8e9f
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 7, 2024
2a6229c
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 8, 2024
1a4214f
wip
sven1977 Nov 8, 2024
5abf033
merge
sven1977 Nov 10, 2024
9a11f52
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Nov 11, 2024
7b71481
wip
sven1977 Nov 11, 2024
5ee3df5
wip
sven1977 Nov 11, 2024
Files changed
1,214 changes: 626 additions & 588 deletions rllib/algorithms/algorithm.py

Large diffs are not rendered by default.
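The algorithm.py diff is collapsed above, so the core of this PR is not visible on this page. As a rough, hypothetical sketch of the accumulation pattern the PR centralizes in the base Algorithm — inferred from the metric keys and MetricsLogger calls that remain in the per-algorithm files below; the helper name and exact placement are assumptions, not the verbatim change:

```python
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)


def accumulate_env_runner_results(metrics, env_runner_results):
    """Hypothetical helper: roughly what the base Algorithm now does once per iteration."""
    # Merge the per-iteration EnvRunner result dicts into the central MetricsLogger.
    metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)
    # Accumulate the lifetime counter once, with a "sum" reduction, instead of
    # repeating this block inside every algorithm's `training_step()`.
    metrics.log_value(
        (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
        metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
        reduce="sum",
    )
    # With `training_step()` returning None, the per-iteration result dict becomes a
    # single reduce over the central logger rather than the method's return value.
    return metrics.reduce()
```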

92 changes: 40 additions & 52 deletions rllib/algorithms/appo/appo.py
@@ -26,9 +26,6 @@
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.metrics import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
ResultDict,
)

logger = logging.getLogger(__name__)

@@ -288,60 +285,51 @@ def __init__(self, config, *args, **kwargs):
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

@override(IMPALA)
def training_step(self) -> ResultDict:
train_results = super().training_step()
def training_step(self) -> None:
if self.config.enable_rl_module_and_learner:
return super().training_step()

train_results = super().training_step()
# Update the target network and the KL coefficient for the APPO-loss.
# The target network update frequency is calculated automatically by the product
# of `num_epochs` setting (usually 1 for APPO) and `minibatch_buffer_size`.
if self.config.enable_rl_module_and_learner:
Comment from sven1977 (PR author): no more hybrid stack

if NUM_TARGET_UPDATES in train_results:
self._counters[NUM_TARGET_UPDATES] += train_results[NUM_TARGET_UPDATES]
self._counters[LAST_TARGET_UPDATE_TS] = train_results[
LAST_TARGET_UPDATE_TS
]
else:
last_update = self._counters[LAST_TARGET_UPDATE_TS]
cur_ts = self._counters[
(
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
)
]
target_update_freq = (
self.config.num_epochs * self.config.minibatch_buffer_size
last_update = self._counters[LAST_TARGET_UPDATE_TS]
cur_ts = self._counters[
(
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
)
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning(
"No data for {}, not updating kl".format(pi_id)
)

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.env_runner.foreach_policy_to_train(update)
]
target_update_freq = self.config.num_epochs * self.config.minibatch_buffer_size
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning("No data for {}, not updating kl".format(pi_id))

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.env_runner.foreach_policy_to_train(update)

return train_results

9 changes: 2 additions & 7 deletions rllib/algorithms/bc/bc.py
Expand Up @@ -3,7 +3,7 @@
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
from ray.rllib.utils.typing import RLModuleSpecType


class BCConfig(MARWILConfig):
@@ -113,15 +113,10 @@ def validate(self) -> None:
class BC(MARWIL):
"""Behavioral Cloning (derived from MARWIL).

Simply uses MARWIL with beta force-set to 0.0.
Uses MARWIL with beta force-set to 0.0.
"""

@classmethod
@override(MARWIL)
def get_default_config(cls) -> AlgorithmConfig:
return BCConfig()

@override(MARWIL)
def training_step(self) -> ResultDict:
# Call MARWIL's training step.
return super().training_step()
33 changes: 5 additions & 28 deletions rllib/algorithms/cql/cql.py
@@ -24,7 +24,7 @@
train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
@@ -39,9 +39,6 @@
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
OFFLINE_SAMPLING_TIMER,
TARGET_NET_UPDATE_TIMER,
@@ -311,14 +308,11 @@ def get_default_policy_class(
return CQLTFPolicy

@override(SAC)
def training_step(self) -> ResultDict:
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack()
else:
def training_step(self) -> None:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self) -> ResultDict:

# Sampling from offline data.
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Return an iterator in case we are using remote learners.
@@ -340,22 +334,6 @@ def _training_step_new_api_stack(self) -> ResultDict:

# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
self.metrics.log_value(
Comment from sven1977 (PR author): these bulky blocks inside each algo's training_step() are no longer necessary
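For context, a minimal standalone sketch of why the explicit lifetime logging became redundant: `log_value(..., reduce="sum")` already accumulates across iterations once it is called in one central place. The MetricsLogger import path, the zero-argument constructor, and the default keep-on-reduce behavior are assumptions here, not taken from this diff:

```python
from ray.rllib.utils.metrics import NUM_ENV_STEPS_TRAINED_LIFETIME
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

metrics = MetricsLogger()

# Simulate three training iterations, each reporting 4000 trained env steps.
for _ in range(3):
    metrics.log_value(NUM_ENV_STEPS_TRAINED_LIFETIME, 4000, reduce="sum")

# Under the assumed defaults this peeks the accumulated lifetime sum (12000),
# so individual algorithms no longer need to re-log the counter themselves.
print(metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME))
```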

NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)
Comment from a collaborator: Awesome change!


# Synchronize weights.
# As the results contain for each policy the loss and in addition the
Expand All @@ -374,8 +352,7 @@ def _training_step_new_api_stack(self) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

@OldAPIStack
def _training_step_old_api_stack(self) -> ResultDict:
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
92 changes: 22 additions & 70 deletions rllib/algorithms/dqn/dqn.py
@@ -55,14 +55,6 @@
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_EPISODES,
NUM_EPISODES_LIFETIME,
NUM_MODULE_STEPS_SAMPLED,
NUM_MODULE_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
REPLAY_BUFFER_ADD_DATA_TIMER,
REPLAY_BUFFER_SAMPLE_TIMER,
@@ -642,7 +634,7 @@ def get_default_policy_class(
return DQNTFPolicy

@override(Algorithm)
def training_step(self) -> ResultDict:
def training_step(self) -> None:
"""DQN training iteration function.

Each training iteration, we:
@@ -657,14 +649,14 @@ def training_step(self) -> ResultDict:
Returns:
The results dict from executing the training iteration.
"""
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack(with_noise_reset=True)
# Old API stack (Policy, RolloutWorker).
else:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
return self._training_step_new_api_stack(with_noise_reset=True)

def _training_step_new_api_stack(self, *, with_noise_reset):
# Alternate between storing and sampling and training.
store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

@@ -688,38 +680,16 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
self.local_replay_buffer.add(episodes)

self.metrics.log_dict(
Comment from sven1977 (PR author): these bulky blocks inside each algo's training_step() are no longer necessary
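As the updated current_ts lookup further down shows, the sampled lifetime counters now live under the ENV_RUNNER_RESULTS key and are read with tuple paths. A small illustrative sketch of that lookup (the function name is hypothetical; the peek calls mirror the ones in this diff):

```python
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    NUM_AGENT_STEPS_SAMPLED_LIFETIME,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)


def current_timestep(metrics, count_steps_by="env_steps"):
    """Illustrative only: mirrors the nested-key lookup used in this file."""
    if count_steps_by == "agent_steps":
        # The agent-step lifetime counter is a per-agent dict; sum its values.
        return sum(
            metrics.peek(
                (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
            ).values()
        )
    # Otherwise use the env-step lifetime counter directly.
    return metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0)
```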

self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
),
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)
self.metrics.log_value(
NUM_ENV_STEPS_SAMPLED_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
reduce="sum",
)
self.metrics.log_value(
NUM_EPISODES_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
reduce="sum",
)
self.metrics.log_dict(
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
default={},
),
key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)

if self.config.count_steps_by == "agent_steps":
current_ts = sum(
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME).values()
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME)
).values()
)
else:
current_ts = self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
current_ts = self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
)

# If enough experiences have been sampled start training.
if current_ts >= self.config.num_steps_sampled_before_learning_starts:
@@ -750,10 +720,17 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
episodes=episodes,
timesteps={
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
)
),
NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(
ENV_RUNNER_RESULTS,
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
)
)
),
},
)
@@ -775,29 +752,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
self.metrics.merge_and_log_n_dicts(
learner_results, key=LEARNER_RESULTS
)
self.metrics.log_value(
Comment from sven1977 (PR author): these bulky blocks inside each algo's training_step() are no longer necessary

NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
if NUM_MODULE_STEPS_TRAINED in stats
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
# (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
# ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")

# Update replay buffer priorities.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
Expand All @@ -818,8 +772,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
"""Training step for the old API stack.
