[RLlib] Fix accumulation of results in (new API stack) Algorithm. #48136
Changes to the CQL algorithm:
@@ -24,7 +24,7 @@
train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,

@@ -39,9 +39,6 @@
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
OFFLINE_SAMPLING_TIMER,
TARGET_NET_UPDATE_TIMER,

@@ -311,14 +308,11 @@ def get_default_policy_class(
return CQLTFPolicy

@override(SAC)
def training_step(self) -> ResultDict:
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack()
else:
def training_step(self) -> None:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self) -> ResultDict:

# Sampling from offline data.
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Return an iterator in case we are using remote learners.

@@ -340,22 +334,6 @@ def _training_step_new_api_stack(self) -> ResultDict:

# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
[Review comment: these bulky blocks inside each algo's training_step() are no longer necessary]
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)
[Review comment: Awesome change!]

# Synchronize weights.
# As the results contain for each policy the loss and in addition the

@@ -374,8 +352,7 @@ def _training_step_new_api_stack(self) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

@OldAPIStack
def _training_step_old_api_stack(self) -> ResultDict:
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:

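For context on what the removed blocks were doing: each iteration re-logged the per-iteration trained-step counts under *_LIFETIME keys with reduce="sum", so the MetricsLogger accumulated lifetime totals inside every algorithm's training_step(). Below is a minimal sketch of that accumulation pattern; the MetricsLogger import path and the default reduction semantics (windowing, clear-on-reduce) are assumptions on our part and are not shown in this diff.

```python
# Minimal, self-contained sketch; import path and reduction defaults are assumed.
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

metrics = MetricsLogger()

# Re-logging a per-iteration count under a *_LIFETIME key with reduce="sum"
# accumulates a running total -- this is what the removed blocks did manually
# inside CQL's training_step().
for env_steps_trained_this_iter in (1000, 1000, 500):
    metrics.log_value(
        "num_env_steps_trained_lifetime",  # hypothetical key name for illustration
        env_steps_trained_this_iter,
        reduce="sum",
    )

print(metrics.peek("num_env_steps_trained_lifetime"))  # running total: 2500 here
```

After this PR, per the title and the inline comments, that lifetime bookkeeping is no longer repeated in each algorithm's training_step(); the (new API stack) Algorithm is expected to accumulate these results centrally.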
Changes to the DQN algorithm:
@@ -55,14 +55,6 @@
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_EPISODES,
NUM_EPISODES_LIFETIME,
NUM_MODULE_STEPS_SAMPLED,
NUM_MODULE_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
REPLAY_BUFFER_ADD_DATA_TIMER,
REPLAY_BUFFER_SAMPLE_TIMER,

@@ -642,7 +634,7 @@ def get_default_policy_class(
return DQNTFPolicy

@override(Algorithm)
def training_step(self) -> ResultDict:
def training_step(self) -> None:
"""DQN training iteration function.

Each training iteration, we:

@@ -657,14 +649,14 @@ def training_step(self) -> ResultDict:
Returns:
The results dict from executing the training iteration.
"""
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack(with_noise_reset=True)
# Old API stack (Policy, RolloutWorker).
else:
# Old API stack (Policy, RolloutWorker, Connector).
if not self.config.enable_env_runner_and_connector_v2:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
# New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
return self._training_step_new_api_stack(with_noise_reset=True)

def _training_step_new_api_stack(self, *, with_noise_reset):
# Alternate between storing and sampling and training.
store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

@@ -688,38 +680,16 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
self.local_replay_buffer.add(episodes)

[Review comment: these bulky blocks inside each algo's training_step() are no longer necessary]
self.metrics.log_dict(
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
),
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)
self.metrics.log_value(
NUM_ENV_STEPS_SAMPLED_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
reduce="sum",
)
self.metrics.log_value(
NUM_EPISODES_LIFETIME,
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
reduce="sum",
)
self.metrics.log_dict(
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
default={},
),
key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)

if self.config.count_steps_by == "agent_steps":
current_ts = sum(
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME).values()
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME)
).values()
)
else:
current_ts = self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
current_ts = self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
)

# If enough experiences have been sampled start training.
if current_ts >= self.config.num_steps_sampled_before_learning_starts:

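The current_ts change above reflects that the lifetime sampling counters are now read from under the ENV_RUNNER_RESULTS sub-key instead of top-level keys. The following hedged sketch illustrates that nested-key access pattern; the import paths are assumptions (the diff does not show them), the tuple-key peek() mirrors the calls in the new code, and log_value() with a tuple key is assumed to accept the same nested paths as peek().

```python
# Sketch of the nested-key lookup behind `current_ts`; import paths assumed.
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

metrics = MetricsLogger()

# Two iterations' worth of sampled env steps, accumulated under the
# ENV_RUNNER_RESULTS sub-key rather than at the top level.
metrics.log_value((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), 2000, reduce="sum")
metrics.log_value((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), 3000, reduce="sum")

# Equivalent of the new `current_ts` lookup (count_steps_by == "env_steps" case).
current_ts = metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME))
print(current_ts)  # expected: 5000 (running sum)
```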
@@ -750,10 +720,17 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
episodes=episodes,
timesteps={
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
)
),
NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME)
self.metrics.peek(
(
ENV_RUNNER_RESULTS,
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
)
)
),
},
)

@@ -775,29 +752,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
self.metrics.merge_and_log_n_dicts(
learner_results, key=LEARNER_RESULTS
)
[Review comment: these bulky blocks inside each algo's training_step() are no longer necessary]
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
if NUM_MODULE_STEPS_TRAINED in stats
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
# (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
# ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")

# Update replay buffer priorities.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):

@@ -818,8 +772,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
inference_only=True,
)

return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
"""Training step for the old API stack.

[Review comment: no more hybrid stack]
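As the "no more hybrid stack" comment suggests, both files end up with the same shape: a single early return into the old API stack, with the new-stack logic living directly in training_step() and returning None. Below is a rough, non-verbatim sketch of that shape; the method body is elided, and the claim that the Algorithm base class takes over accumulation/reduction is inferred from the PR title and review comments rather than shown in this diff.

```python
# Rough sketch of the common dispatch shape after this PR (not a verbatim copy).
def training_step(self) -> None:
    # Old API stack (Policy, RolloutWorker, Connector): delegate and return early.
    if not self.config.enable_env_runner_and_connector_v2:
        return self._training_step_old_api_stack()

    # New API stack (RLModule, Learner, EnvRunner, ConnectorV2): sample, train,
    # and log into self.metrics. No manual *_LIFETIME bookkeeping and no
    # `return self.metrics.reduce()` here -- accumulating and reducing results
    # is presumably left to the Algorithm itself.
    ...
```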