Consistent activation functions between backends #3431

Merged · 1 commit · Mar 8, 2024
10 changes: 9 additions & 1 deletion deepmd/common.py
@@ -52,7 +52,15 @@
_DICT_VAL = TypeVar("_DICT_VAL")
_PRECISION = Literal["default", "float16", "float32", "float64"]
_ACTIVATION = Literal[
"relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf"
"relu",
"relu6",
"softplus",
"sigmoid",
"tanh",
"gelu",
"gelu_tf",
"none",
"linear",
]
__all__.extend(
[
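For reference, a minimal sketch (not part of this diff) of what the widened Literal gives downstream code: type checkers will now accept "none" and "linear" wherever an activation name is expected. The build_layer helper below is hypothetical and only illustrates the annotation.

from typing import Literal

_ACTIVATION = Literal[
    "relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf", "none", "linear"
]

def build_layer(activation: _ACTIVATION = "tanh") -> None:
    # hypothetical helper; only the annotation matters here
    ...

build_layer("linear")  # accepted with this change; previously rejected by type checkers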
59 changes: 51 additions & 8 deletions deepmd/dpmodel/utils/network.py
@@ -10,6 +10,7 @@
datetime,
)
from typing import (
Callable,
ClassVar,
Dict,
List,
@@ -309,14 +310,7 @@
"""
if self.w is None or self.activation_function is None:
raise ValueError("w, b, and activation_function must be set")
if self.activation_function == "tanh":
fn = np.tanh
elif self.activation_function.lower() == "none":

def fn(x):
return x
else:
raise NotImplementedError(self.activation_function)
fn = get_activation_fn(self.activation_function)
y = (
np.matmul(x, self.w) + self.b
if self.b is not None
@@ -332,6 +326,55 @@
return y


def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
activation_function = activation_function.lower()
if activation_function == "tanh":
return np.tanh
elif activation_function == "relu":

def fn(x):
# https://stackoverflow.com/a/47936476/9567349
return x * (x > 0)

return fn
elif activation_function in ("gelu", "gelu_tf"):

def fn(x):
# generated by GitHub Copilot
return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

return fn
elif activation_function == "relu6":

def fn(x):
# generated by GitHub Copilot
return np.minimum(np.maximum(x, 0), 6)

return fn
elif activation_function == "softplus":

def fn(x):
# generated by GitHub Copilot
return np.log(1 + np.exp(x))

return fn
elif activation_function == "sigmoid":

def fn(x):
# generated by GitHub Copilot
return 1 / (1 + np.exp(-x))

return fn
elif activation_function.lower() in ("none", "linear"):

def fn(x):
return x

return fn
else:
raise NotImplementedError(activation_function)

Codecov / codecov/patch warning: added line deepmd/dpmodel/utils/network.py#L375 was not covered by tests.


def make_multilayer_network(T_NetworkLayer, ModuleBase):
class NN(ModuleBase):
"""Native representation of a neural network.
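Side note: the gelu/gelu_tf branch above uses the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal sketch (not part of this diff, and assuming SciPy is available) comparing it against the exact erf-based definition:

import numpy as np
from scipy.special import erf

def gelu_tanh(x):
    # tanh approximation, as returned by get_activation_fn("gelu")
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def gelu_exact(x):
    # exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

x = np.linspace(-5.0, 5.0, 101)
print(np.abs(gelu_tanh(x) - gelu_exact(x)).max())  # on the order of 1e-3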
20 changes: 16 additions & 4 deletions deepmd/pt/utils/utils.py
@@ -21,10 +21,16 @@ def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`."""
if activation.lower() == "relu":
return F.relu
elif activation.lower() == "gelu":
return F.gelu
elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
return lambda x: F.gelu(x, approximate="tanh")
elif activation.lower() == "tanh":
return torch.tanh
elif activation.lower() == "relu6":
return F.relu6
elif activation.lower() == "softplus":
return F.softplus
elif activation.lower() == "sigmoid":
return torch.sigmoid
elif activation.lower() == "linear" or activation.lower() == "none":
return lambda x: x
else:
@@ -42,10 +48,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

if self.activation.lower() == "relu":
return F.relu(x)
elif self.activation.lower() == "gelu":
return F.gelu(x)
elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
return F.gelu(x, approximate="tanh")
elif self.activation.lower() == "tanh":
return torch.tanh(x)
elif self.activation.lower() == "relu6":
return F.relu6(x)
elif self.activation.lower() == "softplus":
return F.softplus(x)
elif self.activation.lower() == "sigmoid":
return torch.sigmoid(x)
elif self.activation.lower() == "linear" or self.activation.lower() == "none":
return x
else:
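A quick illustrative cross-check (not part of this diff) that F.gelu(x, approximate="tanh") matches the NumPy tanh-approximate GELU added in deepmd/dpmodel/utils/network.py; it assumes a PyTorch version that supports the approximate argument (1.12 or later):

import numpy as np
import torch
import torch.nn.functional as F

x = np.random.default_rng(0).normal(size=(4, 4))
ref = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
test = F.gelu(torch.from_numpy(x), approximate="tanh").numpy()
np.testing.assert_allclose(ref, test, atol=1e-10)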
13 changes: 7 additions & 6 deletions deepmd/tf/common.py
@@ -135,14 +135,14 @@ def gelu_wrapper(x):
"tanh": tf.nn.tanh,
"gelu": gelu,
"gelu_tf": gelu_tf,
"None": None,
"none": None,
"linear": lambda x: x,
"none": lambda x: x,
}


def get_activation_func(
activation_fn: Union["_ACTIVATION", None],
) -> Union[Callable[[tf.Tensor], tf.Tensor], None]:
) -> Callable[[tf.Tensor], tf.Tensor]:
"""Get activation function callable based on string name.

Parameters
@@ -161,10 +161,11 @@ def get_activation_func(
if unknown activation function is specified
"""
if activation_fn is None:
return None
if activation_fn not in ACTIVATION_FN_DICT:
activation_fn = "none"
assert activation_fn is not None
if activation_fn.lower() not in ACTIVATION_FN_DICT:
raise RuntimeError(f"{activation_fn} is not a valid activation function")
return ACTIVATION_FN_DICT[activation_fn]
return ACTIVATION_FN_DICT[activation_fn.lower()]


def get_precision(precision: "_PRECISION") -> Any:
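Illustrative usage of the changed TF helper (not part of this diff; assumes a working deepmd.tf installation). With this change the function always returns a callable, None falls back to "none" (identity), and the lookup is case-insensitive:

from deepmd.tf.common import get_activation_func

identity = get_activation_func(None)    # falls back to "none" and returns a callable
linear = get_activation_func("Linear")  # case-insensitive lookup after this change
# callers no longer need to special-case a None return value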
63 changes: 63 additions & 0 deletions source/tests/consistent/test_activation.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp

from .common import (
INSTALLED_PT,
INSTALLED_TF,
parameterized,
)

if INSTALLED_PT:
from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)
if INSTALLED_TF:
from deepmd.tf.common import get_activation_func as get_activation_fn_tf
from deepmd.tf.env import (
tf,
)


@parameterized(
(
"Relu",
"Relu6",
"Softplus",
"Sigmoid",
"Tanh",
"Gelu",
"Gelu_tf",
"Linear",
"None",
),
)
class TestActivationFunctionConsistent(unittest.TestCase):
def setUp(self):
(self.activation,) = self.param
self.random_input = np.random.default_rng().normal(scale=10, size=(10, 10))
self.ref = get_activation_fn_dp(self.activation)(self.random_input)

@unittest.skipUnless(INSTALLED_TF, "TensorFlow is not installed")
def test_tf_consistent_with_ref(self):
if INSTALLED_TF:
place_holder = tf.placeholder(tf.float64, self.random_input.shape)
t_test = get_activation_fn_tf(self.activation)(place_holder)
with tf.Session() as sess:
test = sess.run(t_test, feed_dict={place_holder: self.random_input})
np.testing.assert_allclose(self.ref, test, atol=1e-10)

@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
def test_pt_consistent_with_ref(self):
if INSTALLED_PT:
test = to_numpy_array(
get_activation_fn_pt(self.activation)(
to_torch_tensor(self.random_input)
)
)
np.testing.assert_allclose(self.ref, test, atol=1e-10)
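
To run just this new consistency test locally, an invocation along these lines should work (illustrative; it assumes pytest and at least one of the optional backends is installed):

python -m pytest source/tests/consistent/test_activation.py -v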