From 2a817990cdd8fe526c383078d760a1949942d66a Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 28 Dec 2020 09:57:46 -0500 Subject: [PATCH] WIP. --- rllib/agents/ddpg/ddpg_torch_model.py | 3 +- rllib/agents/maml/maml_tf_policy.py | 2 +- rllib/agents/sac/sac_torch_model.py | 3 +- rllib/models/tf/fcnet.py | 3 +- rllib/models/tf/layers/noisy_layer.py | 6 +- rllib/models/tf/visionnet.py | 4 +- rllib/models/torch/misc.py | 3 +- .../torch/modules/convtranspose2d_stack.py | 4 +- rllib/models/torch/modules/noisy_layer.py | 4 +- rllib/models/utils.py | 81 ++++++++++++++++++- rllib/tests/run_regression_tests.py | 19 ++++- rllib/utils/deprecation.py | 4 +- rllib/utils/exploration/curiosity.py | 3 +- rllib/utils/framework.py | 7 +- rllib/utils/test_utils.py | 2 +- 15 files changed, 123 insertions(+), 25 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_torch_model.py b/rllib/agents/ddpg/ddpg_torch_model.py index 66d910ebf07f..f3108c855771 100644 --- a/rllib/agents/ddpg/ddpg_torch_model.py +++ b/rllib/agents/ddpg/ddpg_torch_model.py @@ -2,7 +2,8 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.framework import get_activation_fn, try_import_torch +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/agents/maml/maml_tf_policy.py b/rllib/agents/maml/maml_tf_policy.py index d07de1495f3c..b9e4d0775820 100644 --- a/rllib/agents/maml/maml_tf_policy.py +++ b/rllib/agents/maml/maml_tf_policy.py @@ -5,10 +5,10 @@ vf_preds_fetches, compute_and_clip_gradients, setup_config, \ ValueNetworkMixin from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.utils import get_activation_fn from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils import try_import_tf -from ray.rllib.utils.framework 
import get_activation_fn tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 9ebb8c75f99d..5f8b05980fed 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -5,7 +5,8 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.framework import get_activation_fn, try_import_torch +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.typing import ModelConfigDict, TensorType diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index 0f72c546bc15..e556741ddd22 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -3,7 +3,8 @@ from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils.framework import get_activation_fn, try_import_tf +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/layers/noisy_layer.py b/rllib/models/tf/layers/noisy_layer.py index 49b11e0c62e9..4498995e0226 100644 --- a/rllib/models/tf/layers/noisy_layer.py +++ b/rllib/models/tf/layers/noisy_layer.py @@ -1,8 +1,8 @@ import numpy as np -from ray.rllib.utils.framework import get_activation_fn, get_variable, \ - try_import_tf -from ray.rllib.utils.framework import TensorType, TensorShape +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import get_variable, try_import_tf, \ + TensorType, TensorShape tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index e09668b49396..c2a8de5d2c97 100644 --- 
a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -3,8 +3,8 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.misc import normc_initializer -from ray.rllib.models.utils import get_filter_config -from ray.rllib.utils.framework import get_activation_fn, try_import_tf +from ray.rllib.models.utils import get_activation_fn, get_filter_config +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py index 307d7644179e..830e8bc33b5e 100644 --- a/rllib/models/torch/misc.py +++ b/rllib/models/torch/misc.py @@ -2,7 +2,8 @@ import numpy as np from typing import Union, Tuple, Any, List -from ray.rllib.utils.framework import get_activation_fn, try_import_torch +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() diff --git a/rllib/models/torch/modules/convtranspose2d_stack.py b/rllib/models/torch/modules/convtranspose2d_stack.py index 0eab6dc7cf16..4fc1735a6700 100644 --- a/rllib/models/torch/modules/convtranspose2d_stack.py +++ b/rllib/models/torch/modules/convtranspose2d_stack.py @@ -1,8 +1,8 @@ from typing import Tuple from ray.rllib.models.torch.misc import Reshape -from ray.rllib.models.utils import get_initializer -from ray.rllib.utils.framework import get_activation_fn, try_import_torch +from ray.rllib.models.utils import get_activation_fn, get_initializer +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() if torch: diff --git a/rllib/models/torch/modules/noisy_layer.py b/rllib/models/torch/modules/noisy_layer.py index ee553c73c89b..f980dba0412a 100644 --- a/rllib/models/torch/modules/noisy_layer.py +++ b/rllib/models/torch/modules/noisy_layer.py @@ -1,7 +1,7 @@ import numpy as np -from 
ray.rllib.utils.framework import get_activation_fn, try_import_torch -from ray.rllib.utils.framework import TensorType +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_torch, TensorType torch, nn = try_import_torch() diff --git a/rllib/models/utils.py b/rllib/models/utils.py index 2c9f076f0ebe..ed50ce08c986 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -1,4 +1,62 @@ -from ray.rllib.utils.framework import try_import_tf, try_import_torch +from typing import Optional + +from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ + try_import_torch + + +def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): + """Returns a framework specific activation function, given a name string. + + Args: + name (Optional[str]): One of "relu", "tanh", "swish", or + "linear" or None (default). + framework (str): One of "jax", "tf|tfe|tf2" or "torch". + + Returns: + A framework-specific activation function. e.g. tf.nn.tanh or + torch.nn.ReLU. None if name in ["linear", None]. + + Raises: + ValueError: If name is an unknown activation function. + """ + # Already a callable, return as-is. + if callable(name): + return name + + # Infer the correct activation function from the string specifier. 
+ if framework == "torch": + if name in ["linear", None]: + return None + if name == "swish": + from ray.rllib.utils.torch_ops import Swish + return Swish + _, nn = try_import_torch() + if name == "relu": + return nn.ReLU + elif name == "tanh": + return nn.Tanh + elif framework == "jax": + if name in ["linear", None]: + return None + jax, _ = try_import_jax() + if name == "swish": + return jax.nn.swish + if name == "relu": + return jax.nn.relu + elif name == "tanh": + return jax.nn.hard_tanh + else: + assert framework in ["tf", "tfe", "tf2"],\ + "Unsupported framework `{}`!".format(framework) + if name in ["linear", None]: + return None + tf1, tf, tfv = try_import_tf() + fn = getattr(tf.nn, name, None) + if fn is not None: + return fn + + raise ValueError("Unknown activation ({}) for framework={}!".format( + name, framework)) def get_filter_config(shape): @@ -40,7 +98,7 @@ def get_initializer(name, framework="tf"): Args: name (str): One of "xavier_uniform" (default), "xavier_normal". - framework (str): One of "tf" or "torch". + framework (str): One of "jax", "tf|tfe|tf2" or "torch". Returns: A framework-specific initializer function, e.g. @@ -50,14 +108,33 @@ def get_initializer(name, framework="tf"): Raises: ValueError: If name is an unknown initializer. """ + # Already a callable, return as-is. + if callable(name): + return name + + if framework == "jax": + _, flax = try_import_jax() + assert flax is not None,\ + "`flax` not installed. Try `pip install jax flax`." + import flax.linen as nn + if name in [None, "default", "xavier_uniform"]: + return nn.initializers.xavier_uniform() + elif name == "xavier_normal": + return nn.initializers.xavier_normal() if framework == "torch": _, nn = try_import_torch() + assert nn is not None,\ + "`torch` not installed. Try `pip install torch`." 
if name in [None, "default", "xavier_uniform"]: return nn.init.xavier_uniform_ elif name == "xavier_normal": return nn.init.xavier_normal_ else: + assert framework in ["tf", "tfe", "tf2"],\ + "Unsupported framework `{}`!".format(framework) tf1, tf, tfv = try_import_tf() + assert tf is not None,\ + "`tensorflow` not installed. Try `pip install tensorflow`." if name in [None, "default", "xavier_uniform"]: return tf.keras.initializers.GlorotUniform elif name == "xavier_normal": diff --git a/rllib/tests/run_regression_tests.py b/rllib/tests/run_regression_tests.py index 9a2c3313779b..3f42147e4071 100644 --- a/rllib/tests/run_regression_tests.py +++ b/rllib/tests/run_regression_tests.py @@ -25,17 +25,25 @@ import ray from ray.tune import run_experiments from ray.rllib import _register_all +from ray.rllib.utils.deprecation import deprecation_warning parser = argparse.ArgumentParser() parser.add_argument( - "--torch", - action="store_true", - help="Runs all tests with PyTorch enabled.") + "--framework", + choices=["jax", "tf2", "tf", "tfe", "torch"], + default="tf", + help="The deep learning framework to use.") parser.add_argument( "--yaml-dir", type=str, help="The directory in which to find all yamls to test.") +# Obsoleted arg, use --framework=torch instead. +parser.add_argument( + "--torch", + action="store_true", + help="Runs all tests with PyTorch enabled.") + if __name__ == "__main__": args = parser.parse_args() @@ -69,8 +77,11 @@ # Add torch option to exp configs. for exp in experiments.values(): + exp["config"]["framework"] = args.framework if args.torch: + deprecation_warning(old="--torch", new="--framework=torch") exp["config"]["framework"] = "torch" + args.framework = "torch" # Print out the actual config. 
print("== Test config ==") @@ -82,7 +93,7 @@ for i in range(3): try: ray.init(num_cpus=5) - trials = run_experiments(experiments, resume=False, verbose=1) + trials = run_experiments(experiments, resume=False, verbose=2) finally: ray.shutdown() _register_all() diff --git a/rllib/utils/deprecation.py b/rllib/utils/deprecation.py index 8f3828b6a15b..05788059bed1 100644 --- a/rllib/utils/deprecation.py +++ b/rllib/utils/deprecation.py @@ -15,8 +15,8 @@ def deprecation_warning(old, new=None, error=None): Args: old (str): A description of the "thing" that is to be deprecated. new (Optional[str]): A description of the new "thing" that replaces it. - error (Optional[bool,Exception]): Whether or which exception to throw. - If True, throw ValueError. + error (Optional[Union[bool,Exception]]): Whether or which exception to + throw. If True, throw ValueError. """ msg = "`{}` has been deprecated.{}".format( old, (" Use `{}` instead.".format(new) if new else "")) diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index a9434e1a1174..ec91c53d39f5 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -9,11 +9,12 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchMultiCategorical +from ray.rllib.models.utils import get_activation_fn from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import NullContextManager from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration -from ray.rllib.utils.framework import get_activation_fn, try_import_tf, \ +from ray.rllib.utils.framework import try_import_tf, \ try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 8c7d223074b6..e7323ff6c79a 100644 --- 
a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -4,6 +4,7 @@ import sys from typing import Any, Optional +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType logger = logging.getLogger(__name__) @@ -252,7 +253,7 @@ def get_variable(value, return value -# TODO: (sven) move to models/utils.py +# Deprecated: Use rllib.models.utils::get_activation_fn instead. def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): """Returns a framework specific activation function, given a name string. @@ -268,6 +269,10 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): Raises: ValueError: If name is an unknown activation function. """ + deprecation_warning( + "rllib/utils/framework.py::get_activation_fn", + "rllib/models/utils.py::get_activation_fn", + error=False) if framework == "torch": if name in ["linear", None]: return None diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 5460d9c277f4..d0263eff5668 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -5,7 +5,7 @@ from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ try_import_torch -jax, flax = try_import_jax() +jax, _ = try_import_jax() tf1, tf, tfv = try_import_tf() if tf1: eager_mode = None