Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[no_early_kickoff] [AIR] Deprecate TensorflowCheckpoint.get_model model_definition parameter #33776

Merged
merged 32 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5984f1e
Rename `"model_state_dict"` to `"model"`
bveeramani Nov 22, 2022
8f58490
Revert "Rename `"model_state_dict"` to `"model"`"
bveeramani Nov 22, 2022
63432eb
Merge remote-tracking branch 'upstream/master'
bveeramani Dec 5, 2022
2c33947
Merge remote-tracking branch 'upstream/master'
bveeramani Dec 6, 2022
89694a0
Merge remote-tracking branch 'upstream/master'
bveeramani Dec 19, 2022
fe60ca3
Merge remote-tracking branch 'upstream/master'
bveeramani Dec 27, 2022
d45ae9a
Merge remote-tracking branch 'upstream/master'
bveeramani Jan 2, 2023
c703dfc
Merge remote-tracking branch 'upstream/master'
bveeramani Jan 6, 2023
81dd25c
Merge remote-tracking branch 'upstream/master'
bveeramani Jan 19, 2023
fba788e
Merge remote-tracking branch 'upstream/master'
bveeramani Jan 24, 2023
de05655
Update annotations.py
bveeramani Jan 26, 2023
fd2ff91
Revert "Update annotations.py"
bveeramani Jan 26, 2023
7c3ac36
Merge remote-tracking branch 'upstream/master'
bveeramani Feb 14, 2023
9c7b546
Merge remote-tracking branch 'upstream/master'
bveeramani Feb 28, 2023
9d7a15f
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 2, 2023
4f02efe
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 2, 2023
46f43d9
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 6, 2023
a644307
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 10, 2023
e5d0bf1
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 20, 2023
fe7286c
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 21, 2023
51b3a09
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 24, 2023
7271109
Initial commit
bveeramani Mar 27, 2023
8447018
Format `test_keras_callback.py`
bveeramani Mar 27, 2023
79f8be9
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 28, 2023
5b0e6fc
Merge branch 'master' into deprecate-model-definition
bveeramani Mar 28, 2023
78d0d0d
Fix tests
bveeramani Mar 28, 2023
d583e48
Update test_tensorflow_checkpoint.py
bveeramani Mar 28, 2023
20dcaa9
Create temp.py
bveeramani Mar 28, 2023
224aa51
Delete temp.py
bveeramani Mar 28, 2023
fb372c8
Update test_tensorflow_checkpoint.py
bveeramani Mar 28, 2023
53bede6
Merge remote-tracking branch 'upstream/master'
bveeramani Mar 28, 2023
c756b23
Merge branch 'master' into deprecate-model-definition
bveeramani Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/ray/train/tensorflow/tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,13 @@ def get_model(
Returns:
The Tensorflow Keras model stored in the checkpoint.
"""
# TODO: Remove `model_definition` in 2.6.
if model_definition is not None:
warnings.warn(
raise DeprecationWarning(
"The `model_definition` parameter is deprecated. Use the `model` "
"parameter instead.",
DeprecationWarning,
"parameter instead."
)
model = model_definition

if model is not None and self._flavor is not self.Flavor.MODEL_WEIGHTS:
warnings.warn(
"TensorflowCheckpoint was created from "
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_model():
def test_model_definition_raises_deprecation_warning():
model = get_model()
checkpoint = TensorflowCheckpoint.from_model(model)
with pytest.deprecated_call():
with pytest.raises(DeprecationWarning):
checkpoint.get_model(model_definition=get_model)


Expand Down