Skip to content

Commit

Permalink
[no_early_kickoff] [AIR] Deprecate TensorflowCheckpoint.get_model `…
Browse files Browse the repository at this point in the history
…model_definition` parameter (#33776)

model_definition was deprecated in Ray 2.3 in favor of model. This PR escalates the deprecation warning to a deprecation error.

---------

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
  • Loading branch information
bveeramani authored Mar 28, 2023
1 parent 2988a38 commit 4310993
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/ray/train/tensorflow/tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,13 @@ def get_model(
Returns:
The Tensorflow Keras model stored in the checkpoint.
"""
# TODO: Remove `model_definition` in 2.6.
if model_definition is not None:
warnings.warn(
raise DeprecationWarning(
"The `model_definition` parameter is deprecated. Use the `model` "
"parameter instead.",
DeprecationWarning,
"parameter instead."
)
model = model_definition

if model is not None and self._flavor is not self.Flavor.MODEL_WEIGHTS:
warnings.warn(
"TensorflowCheckpoint was created from "
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_model():
def test_model_definition_raises_deprecation_warning():
model = get_model()
checkpoint = TensorflowCheckpoint.from_model(model)
with pytest.deprecated_call():
with pytest.raises(DeprecationWarning):
checkpoint.get_model(model_definition=get_model)


Expand Down

0 comments on commit 4310993

Please sign in to comment.