-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rllib] Allow envs to be auto-registered; add on_train_result callback with curriculum example #3451
[rllib] Allow envs to be auto-registered; add on_train_result callback with curriculum example #3451
Changes from all commits
f6ade0d
94cd839
ede34ca
89d9c8e
4760437
2edb020
c613c05
be97fae
2616824
3737fd6
33dade4
1583ac3
3d4b5c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,20 +2,21 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from datetime import datetime | ||
import copy | ||
import os | ||
import logging | ||
import os | ||
import pickle | ||
import six | ||
import tempfile | ||
from datetime import datetime | ||
import tensorflow as tf | ||
|
||
import ray | ||
from ray.rllib.models import MODEL_DEFAULTS | ||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator | ||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer | ||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts | ||
from ray.tune.registry import ENV_CREATOR, _global_registry | ||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry | ||
from ray.tune.trainable import Trainable | ||
from ray.tune.trial import Resources | ||
from ray.tune.logger import UnifiedLogger | ||
|
@@ -40,6 +41,7 @@ | |
"on_episode_step": None, # arg: {"env": .., "episode": ...} | ||
"on_episode_end": None, # arg: {"env": .., "episode": ...} | ||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...} | ||
"on_train_result": None, # arg: {"agent": ..., "result": ...} | ||
}, | ||
|
||
# === Policy === | ||
|
@@ -274,7 +276,7 @@ def __init__(self, config=None, env=None, logger_creator=None): | |
self.global_vars = {"timestep": 0} | ||
|
||
# Agents allow env ids to be passed directly to the constructor. | ||
self._env_id = env or config.get("env") | ||
self._env_id = _register_if_needed(env or config.get("env")) | ||
|
||
# Create a default logger creator if no logger_creator is specified | ||
if logger_creator is None: | ||
|
@@ -316,7 +318,13 @@ def train(self): | |
logger.debug("synchronized filters: {}".format( | ||
self.local_evaluator.filters)) | ||
|
||
return Trainable.train(self) | ||
result = Trainable.train(self) | ||
if self.config["callbacks"].get("on_train_result"): | ||
self.config["callbacks"]["on_train_result"]({ | ||
"agent": self, | ||
"result": result, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. .copy()? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (or deepcopy) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the advantage here is the user can mutate it if they want. |
||
}) | ||
return result | ||
|
||
def _setup(self, config): | ||
env = self._env_id | ||
|
@@ -444,6 +452,15 @@ def _restore(self, checkpoint_path): | |
self.__setstate__(extra_data) | ||
|
||
|
||
def _register_if_needed(env_object): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this happen earlier? I.e., before the config is pickled, somewhere in … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think it really matters, since the pickled config can't be resolved anyway. Basically it just adds a little more text there. |
||
if isinstance(env_object, six.string_types): | ||
return env_object | ||
elif isinstance(env_object, type): | ||
name = env_object.__name__ | ||
register_env(name, lambda config: env_object(config)) | ||
return name | ||
|
||
|
||
def get_agent_class(alg): | ||
"""Returns the class of a known agent given its name.""" | ||
|
||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can probably put together a page full of examples (in a later PR)