Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] JAXPolicy prep PR #2 (move get_activation_fn, minor fixes and preparations). #13091

Merged
merged 1 commit into from
Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rllib/agents/ddpg/ddpg_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/maml/maml_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.framework import get_activation_fn

tf1, tf, tfv = try_import_tf()

Expand Down
3 changes: 2 additions & 1 deletion rllib/agents/sac/sac_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType

Expand Down
3 changes: 2 additions & 1 deletion rllib/models/tf/fcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

tf1, tf, tfv = try_import_tf()
Expand Down
6 changes: 3 additions & 3 deletions rllib/models/tf/layers/noisy_layer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np

from ray.rllib.utils.framework import get_activation_fn, get_variable, \
try_import_tf
from ray.rllib.utils.framework import TensorType, TensorShape
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import get_variable, try_import_tf, \
TensorType, TensorShape

tf1, tf, tfv = try_import_tf()

Expand Down
4 changes: 2 additions & 2 deletions rllib/models/tf/visionnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
Expand Down
3 changes: 2 additions & 1 deletion rllib/models/torch/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
from typing import Union, Tuple, Any, List

from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/torch/modules/convtranspose2d_stack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Tuple

from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.utils import get_initializer
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn, get_initializer
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
if torch:
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/torch/modules/noisy_layer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.utils.framework import TensorType
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch, TensorType

torch, nn = try_import_torch()

Expand Down
81 changes: 79 additions & 2 deletions rllib/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,62 @@
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from typing import Optional

from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch


def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework-specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu", "tanh", "swish", "linear", or
            None. A callable may also be passed and is returned unchanged.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific activation function, e.g. tf.nn.tanh or
        torch.nn.ReLU. None if name in ["linear", None].

    Raises:
        ValueError: If name is an unknown activation function.
    """
    # A callable was passed in directly -> nothing to look up.
    if callable(name):
        return name

    # Resolve the string specifier per framework.
    if framework == "torch":
        # "linear" (or None) means: no activation layer at all.
        if name in ("linear", None):
            return None
        if name == "swish":
            # Lazy import: torch_ops itself pulls in torch.
            from ray.rllib.utils.torch_ops import Swish
            return Swish
        _, nn = try_import_torch()
        torch_fns = {"relu": nn.ReLU, "tanh": nn.Tanh}
        if name in torch_fns:
            return torch_fns[name]
    elif framework == "jax":
        if name in ("linear", None):
            return None
        jax, _ = try_import_jax()
        # NOTE: "tanh" deliberately maps to jax.nn.hard_tanh.
        jax_fns = {
            "swish": jax.nn.swish,
            "relu": jax.nn.relu,
            "tanh": jax.nn.hard_tanh,
        }
        if name in jax_fns:
            return jax_fns[name]
    else:
        assert framework in ("tf", "tfe", "tf2"),\
            "Unsupported framework `{}`!".format(framework)
        if name in ("linear", None):
            return None
        # tf.nn exposes all supported activations by name.
        _, tf, _ = try_import_tf()
        activation = getattr(tf.nn, name, None)
        if activation is not None:
            return activation

    raise ValueError("Unknown activation ({}) for framework={}!".format(
        name, framework))


def get_filter_config(shape):
Expand Down Expand Up @@ -40,7 +98,7 @@ def get_initializer(name, framework="tf"):

Args:
name (str): One of "xavier_uniform" (default), "xavier_normal".
framework (str): One of "tf" or "torch".
framework (str): One of "jax", "tf|tfe|tf2" or "torch".

Returns:
A framework-specific initializer function, e.g.
Expand All @@ -50,14 +108,33 @@ def get_initializer(name, framework="tf"):
Raises:
ValueError: If name is an unknown initializer.
"""
# Already a callable, return as-is.
if callable(name):
return name

if framework == "jax":
_, flax = try_import_jax()
assert flax is not None,\
"`flax` not installed. Try `pip install jax flax`."
import flax.linen as nn
if name in [None, "default", "xavier_uniform"]:
return nn.initializers.xavier_uniform()
elif name == "xavier_normal":
return nn.initializers.xavier_normal()
if framework == "torch":
_, nn = try_import_torch()
assert nn is not None,\
"`torch` not installed. Try `pip install torch`."
if name in [None, "default", "xavier_uniform"]:
return nn.init.xavier_uniform_
elif name == "xavier_normal":
return nn.init.xavier_normal_
else:
assert framework in ["tf", "tfe", "tf2"],\
"Unsupported framework `{}`!".format(framework)
tf1, tf, tfv = try_import_tf()
assert tf is not None,\
"`tensorflow` not installed. Try `pip install tensorflow`."
if name in [None, "default", "xavier_uniform"]:
return tf.keras.initializers.GlorotUniform
elif name == "xavier_normal":
Expand Down
19 changes: 15 additions & 4 deletions rllib/tests/run_regression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,25 @@
import ray
from ray.tune import run_experiments
from ray.rllib import _register_all
from ray.rllib.utils.deprecation import deprecation_warning

parser = argparse.ArgumentParser()
parser.add_argument(
"--torch",
action="store_true",
help="Runs all tests with PyTorch enabled.")
"--framework",
choices=["jax", "tf2", "tf", "tfe", "torch"],
default="tf",
help="The deep learning framework to use.")
parser.add_argument(
"--yaml-dir",
type=str,
help="The directory in which to find all yamls to test.")

# Obsoleted arg, use --framework=torch instead.
parser.add_argument(
"--torch",
action="store_true",
help="Runs all tests with PyTorch enabled.")

if __name__ == "__main__":
args = parser.parse_args()

Expand Down Expand Up @@ -69,8 +77,11 @@

# Add torch option to exp configs.
for exp in experiments.values():
exp["config"]["framework"] = args.framework
if args.torch:
deprecation_warning(old="--torch", new="--framework=torch")
exp["config"]["framework"] = "torch"
args.framework = "torch"

# Print out the actual config.
print("== Test config ==")
Expand All @@ -82,7 +93,7 @@
for i in range(3):
try:
ray.init(num_cpus=5)
trials = run_experiments(experiments, resume=False, verbose=1)
trials = run_experiments(experiments, resume=False, verbose=2)
finally:
ray.shutdown()
_register_all()
Expand Down
4 changes: 2 additions & 2 deletions rllib/utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def deprecation_warning(old, new=None, error=None):
Args:
old (str): A description of the "thing" that is to be deprecated.
new (Optional[str]): A description of the new "thing" that replaces it.
error (Optional[bool,Exception]): Whether or which exception to throw.
If True, throw ValueError.
error (Optional[Union[bool,Exception]]): Whether or which exception to
throw. If True, throw ValueError.
"""
msg = "`{}` has been deprecated.{}".format(
old, (" Use `{}` instead.".format(new) if new else ""))
Expand Down
3 changes: 2 additions & 1 deletion rllib/utils/exploration/curiosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchMultiCategorical
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import NullContextManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import get_activation_fn, try_import_tf, \
from ray.rllib.utils.framework import try_import_tf, \
try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot
Expand Down
7 changes: 6 additions & 1 deletion rllib/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Optional

from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -252,7 +253,7 @@ def get_variable(value,
return value


# TODO: (sven) move to models/utils.py
# Deprecated: Use rllib.models.utils::get_activation_fn instead.
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
"""Returns a framework specific activation function, given a name string.

Expand All @@ -268,6 +269,10 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
Raises:
ValueError: If name is an unknown activation function.
"""
deprecation_warning(
"rllib/utils/framework.py::get_activation_fn",
"rllib/models/utils.py::get_activation_fn",
error=False)
if framework == "torch":
if name in ["linear", None]:
return None
Expand Down
2 changes: 1 addition & 1 deletion rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch

jax, flax = try_import_jax()
jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
eager_mode = None
Expand Down