
Fix the serialization bug of rectified adam. #1375

Merged: 3 commits, merged Mar 24, 2020
tensorflow_addons/optimizers/rectified_adam.py (17 changes: 15 additions, 2 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Rectified Adam (RAdam) optimizer."""
+import warnings

 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
@@ -79,7 +80,10 @@ def __init__(
         weight_decay: FloatTensorLike = 0.0,
         amsgrad: bool = False,
         sma_threshold: FloatTensorLike = 5.0,
-        total_steps: int = 0,
+        # float for total_steps is here to be able to load models created before
+        # https://github.com/tensorflow/addons/pull/1375 was merged. It should be
+        # removed for Addons 0.11.
+        total_steps: Union[int, float] = 0,
         warmup_proportion: FloatTensorLike = 0.1,
         min_lr: FloatTensorLike = 0.0,
         name: str = "RectifiedAdam",
@@ -123,7 +127,16 @@ def __init__(
         self._set_hyper("decay", self._initial_decay)
         self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("sma_threshold", sma_threshold)
-        self._set_hyper("total_steps", float(total_steps))
+        if isinstance(total_steps, float):
+            warnings.warn(
+                "The parameter `total_steps` passed to the __init__ of RectifiedAdam "
+                "is a float. This behavior is deprecated and in Addons 0.11, this "
+                "will raise an error. Use an int instead. If you get this message "
+                "when loading a model, save it again and the `total_steps` parameter "
+                "will automatically be converted to an int.",
+                DeprecationWarning,
+            )
+        self._set_hyper("total_steps", int(total_steps))
         self._set_hyper("warmup_proportion", warmup_proportion)
         self._set_hyper("min_lr", min_lr)
         self.epsilon = epsilon or tf.keras.backend.epsilon()

Review conversation on this hunk:

Member: Any reason to select 0.11? Also, how do we keep track of this when we get to version 0.11?

@gabrieldemarmiesse (Member Author, Mar 24, 2020): We need to give a date in the warning, otherwise users won't care. They'll just say: "I'll leave tomorrow's problems to tomorrow's me." :P No idea how to track it, though. Milestones in GitHub?

Member: Milestones require a due date, not a version. Why not provide a deprecation date instead of a version? But that directly means we'll need to create a new release on that due date.

@gabrieldemarmiesse (Member Author, Mar 24, 2020): When stopping support for a branch, it's indeed common to specify a date: e.g., the Python org stopped support for the 2.7 branch of Python on January 1st, 2020. When deprecating a feature and removing it later, it's more common to use a specific version:

    /opt/conda/lib/python3.7/site-packages/tensorflow/python/training/tracking/data_structures.py:718: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working.

I believe it makes more sense to use a specific version. We can make a milestone; we're not forced to add a date. For example: https://github.com/pytest-dev/pytest/milestones

Member: Sounds great! I'll accept this PR. We should create a milestone for this as well.

@gabrieldemarmiesse (Member Author): Sure, I can do that.

Member: Haha, accidentally created a duplicate issue (#1381, closed now).
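The shape of the bug and of the fix can be illustrated with a minimal, self-contained sketch (the `TinyOptimizer` class below is hypothetical, not the Addons code): before the fix, `total_steps` was stored with `float(total_steps)`, so `get_config()` serialized a float, which then flowed back into `__init__` on deserialization. The fix accepts the stray float once, emits a `DeprecationWarning`, and stores an int so the next save is clean.

```python
import warnings


class TinyOptimizer:
    """Hypothetical stand-in mirroring the RectifiedAdam fix: accept a
    float `total_steps` from old saved configs, warn, and store an int."""

    def __init__(self, total_steps=0):
        if isinstance(total_steps, float):
            warnings.warn(
                "`total_steps` should be an int; floats are deprecated "
                "and will raise an error in a future release.",
                DeprecationWarning,
            )
        # The actual fix's idea: store the hyperparameter as an int, so
        # get_config() serializes an int from now on.
        self.total_steps = int(total_steps)

    def get_config(self):
        return {"total_steps": self.total_steps}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# A config saved before the fix holds a float...
old_config = {"total_steps": 10000.0}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    opt = TinyOptimizer.from_config(old_config)

print(opt.get_config())             # {'total_steps': 10000}
print(caught[0].category.__name__)  # DeprecationWarning
```

Saving `opt` again now serializes an int, which is why the warning message can promise the conversion happens automatically on the next save.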
tensorflow_addons/optimizers/rectified_adam_test.py (9 changes: 9 additions, 0 deletions)

@@ -172,5 +172,14 @@ def test_get_config(self):
         self.assertEqual(config["total_steps"], 0)


+def test_serialization():
+    optimizer = RectifiedAdam(
+        lr=1e-3, total_steps=10000, warmup_proportion=0.1, min_lr=1e-5,
+    )
+    config = tf.keras.optimizers.serialize(optimizer)
+    new_optimizer = tf.keras.optimizers.deserialize(config)
+    assert new_optimizer.get_config() == optimizer.get_config()


 if __name__ == "__main__":
     sys.exit(pytest.main([__file__]))
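The test above checks that the serialize → deserialize → `get_config()` round-trip is lossless. The same pattern can be sketched without TensorFlow (all names below — `ToyOptimizer`, `serialize`, `deserialize` — are hypothetical stand-ins shaped like the `tf.keras.optimizers` helpers, not their real implementations):

```python
class ToyOptimizer:
    """Hypothetical optimizer-like object with a config round-trip."""

    def __init__(self, lr=1e-3, total_steps=0):
        self.lr = lr
        self.total_steps = int(total_steps)

    def get_config(self):
        return {"lr": self.lr, "total_steps": self.total_steps}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def serialize(optimizer):
    # Shaped like tf.keras.optimizers.serialize: class name plus config dict.
    return {"class_name": type(optimizer).__name__,
            "config": optimizer.get_config()}


def deserialize(payload, registry):
    # Shaped like tf.keras.optimizers.deserialize: resolve the class
    # by name, then rebuild the object from its config.
    return registry[payload["class_name"]].from_config(payload["config"])


opt = ToyOptimizer(lr=1e-3, total_steps=10000)
payload = serialize(opt)
clone = deserialize(payload, {"ToyOptimizer": ToyOptimizer})

# The round-trip must reproduce the config exactly — the assertion the
# real test_serialization makes against RectifiedAdam.
assert clone.get_config() == opt.get_config()
print("round-trip ok")
```

This is exactly where the original bug bit: with `total_steps` stored as a float, the rebuilt object's config matched, but its `__init__` received a float where an int was annotated.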