From 5984f1ec6f49d8a2e5ae3a362f9c71efe1c8c36e Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 21 Nov 2022 21:00:42 -0800 Subject: [PATCH 01/11] Rename `"model_state_dict"` to `"model"` --- .../train/examples/pytorch/tune_cifar_torch_pbt_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py index 71472227a249..46ea6ab3947a 100644 --- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py +++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py @@ -83,7 +83,7 @@ def train_func(config): checkpoint_dict = session.get_checkpoint().to_dict() # Load in model - model_state = checkpoint_dict["model_state_dict"] + model_state = checkpoint_dict["model"] model.load_state_dict(model_state) # Load in optimizer @@ -146,7 +146,7 @@ def train_func(config): checkpoint = Checkpoint.from_dict( { "epoch": epoch, - "model_state_dict": model.state_dict(), + "model": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), } ) From 8f58490187070bc8122629f97f9c9ff07b716f85 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 21 Nov 2022 21:00:49 -0800 Subject: [PATCH 02/11] Revert "Rename `"model_state_dict"` to `"model"`" This reverts commit 5984f1ec6f49d8a2e5ae3a362f9c71efe1c8c36e. --- .../train/examples/pytorch/tune_cifar_torch_pbt_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py index 46ea6ab3947a..71472227a249 100644 --- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py +++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py @@ -83,7 +83,7 @@ def train_func(config): checkpoint_dict = session.get_checkpoint().to_dict() # Load in model - model_state = checkpoint_dict["model"] + model_state = checkpoint_dict["model_state_dict"] model.load_state_dict(model_state) # Load in optimizer @@ -146,7 +146,7 @@ def train_func(config): checkpoint = Checkpoint.from_dict( { "epoch": epoch, - "model": model.state_dict(), + "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), } ) From de05655b003c96b3cb9194e6cf21155e04ee22f5 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 26 Jan 2023 11:56:49 -0800 Subject: [PATCH 03/11] Update annotations.py Signed-off-by: Balaji Veeramani --- python/ray/util/annotations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/util/annotations.py b/python/ray/util/annotations.py index 9996b092fcab..f7b93746f910 100644 --- a/python/ray/util/annotations.py +++ b/python/ray/util/annotations.py @@ -49,7 +49,7 @@ def PublicAPI(*args, **kwargs): def wrap(obj): if stability in ["alpha", "beta"]: message = ( - f"PublicAPI ({stability}): This API is in {stability} " + f"**PublicAPI ({stability}):** This API is in {stability} " "and may change before becoming stable." ) else: @@ -80,7 +80,8 @@ def DeveloperAPI(*args, **kwargs): def wrap(obj): _append_doc( - obj, message="DeveloperAPI: This API may change across minor Ray releases." + obj, + message="**DeveloperAPI:** This API may change across minor Ray releases.", ) _mark_annotated(obj) return obj From fd2ff917e1cc3258554c56b283db8e8e155cff9a Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 26 Jan 2023 12:02:30 -0800 Subject: [PATCH 04/11] Revert "Update annotations.py" This reverts commit de05655b003c96b3cb9194e6cf21155e04ee22f5. Signed-off-by: Balaji Veeramani --- python/ray/util/annotations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/ray/util/annotations.py b/python/ray/util/annotations.py index f7b93746f910..9996b092fcab 100644 --- a/python/ray/util/annotations.py +++ b/python/ray/util/annotations.py @@ -49,7 +49,7 @@ def PublicAPI(*args, **kwargs): def wrap(obj): if stability in ["alpha", "beta"]: message = ( - f"**PublicAPI ({stability}):** This API is in {stability} " + f"PublicAPI ({stability}): This API is in {stability} " "and may change before becoming stable." ) else: @@ -80,8 +80,7 @@ def DeveloperAPI(*args, **kwargs): def wrap(obj): _append_doc( - obj, - message="**DeveloperAPI:** This API may change across minor Ray releases.", + obj, message="DeveloperAPI: This API may change across minor Ray releases." ) _mark_annotated(obj) return obj From 727110918dd4a7125ac6f6176f1bf58460b56706 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 16:43:16 -0700 Subject: [PATCH 05/11] Initial commit Signed-off-by: Balaji Veeramani --- python/ray/air/tests/test_keras_callback.py | 2 +- python/ray/train/tensorflow/tensorflow_checkpoint.py | 8 ++++---- python/ray/train/tests/test_tensorflow_checkpoint.py | 2 +- python/ray/train/tests/test_tensorflow_predictor.py | 4 ++-- python/ray/train/tests/test_tensorflow_trainer.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/ray/air/tests/test_keras_callback.py b/python/ray/air/tests/test_keras_callback.py index 3e4113ef40f2..2708f3c15354 100644 --- a/python/ray/air/tests/test_keras_callback.py +++ b/python/ray/air/tests/test_keras_callback.py @@ -197,7 +197,7 @@ def test_keras_callback_e2e(): assert checkpoint._flavor == TensorflowCheckpoint.Flavor.MODEL_WEIGHTS predictor = TensorflowPredictor.from_checkpoint( - checkpoint, model_definition=build_model + checkpoint, model=build_model ) items = np.random.uniform(0, 1, size=(10, 1)) diff --git a/python/ray/train/tensorflow/tensorflow_checkpoint.py b/python/ray/train/tensorflow/tensorflow_checkpoint.py index a9301f2b61a3..10ea4770b260 100644 --- a/python/ray/train/tensorflow/tensorflow_checkpoint.py +++ b/python/ray/train/tensorflow/tensorflow_checkpoint.py @@ -223,13 +223,13 @@ def get_model( Returns: The Tensorflow Keras model stored in the checkpoint. """ + # TODO: Remove `model_definition` in 2.6. if model_definition is not None: - warnings.warn( + raise DeprecationWarning( "The `model_definition` parameter is deprecated. Use the `model` " - "parameter instead.", - DeprecationWarning, + "parameter instead." ) - model = model_definition + if model is not None and self._flavor is not self.Flavor.MODEL_WEIGHTS: warnings.warn( "TensorflowCheckpoint was created from " diff --git a/python/ray/train/tests/test_tensorflow_checkpoint.py b/python/ray/train/tests/test_tensorflow_checkpoint.py index 3bcf9a064b0a..63202a7dbe02 100644 --- a/python/ray/train/tests/test_tensorflow_checkpoint.py +++ b/python/ray/train/tests/test_tensorflow_checkpoint.py @@ -41,7 +41,7 @@ def test_model_definition_raises_deprecation_warning(): model = get_model() checkpoint = TensorflowCheckpoint.from_model(model) with pytest.deprecated_call(): - checkpoint.get_model(model_definition=get_model) + checkpoint.get_model(model=get_model) def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool: diff --git a/python/ray/train/tests/test_tensorflow_predictor.py b/python/ray/train/tests/test_tensorflow_predictor.py index 33462d93980d..26a9085b18fb 100644 --- a/python/ray/train/tests/test_tensorflow_predictor.py +++ b/python/ray/train/tests/test_tensorflow_predictor.py @@ -88,7 +88,7 @@ def test_init(): predictor = TensorflowPredictor(model=build_model(), preprocessor=preprocessor) checkpoint_predictor = TensorflowPredictor.from_checkpoint( - checkpoint, model_definition=build_raw_model + checkpoint, model=build_raw_model ) assert checkpoint_predictor._model.get_weights() == predictor._model.get_weights() @@ -241,7 +241,7 @@ def test_tensorflow_predictor_no_training(use_gpu): model = build_model() checkpoint = TensorflowCheckpoint.from_model(model) batch_predictor = BatchPredictor.from_checkpoint( - checkpoint, TensorflowPredictor, model_definition=build_model, use_gpu=use_gpu + checkpoint, TensorflowPredictor, model=build_model, use_gpu=use_gpu ) predict_dataset = ray.data.range(3) predictions = batch_predictor.predict(predict_dataset) diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index 8fa41a2b805d..06950866eae5 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -84,7 +84,7 @@ def train_func(): assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) batch_predictor = BatchPredictor.from_checkpoint( - result.checkpoint, TensorflowPredictor, model_definition=build_model + result.checkpoint, TensorflowPredictor, model=build_model ) predict_dataset = ray.data.range(3) From 8447018675368f4674a29aa88e97ea1d919ffed2 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 16:43:49 -0700 Subject: [PATCH 06/11] Format `test_keras_callback.py` Signed-off-by: Balaji Veeramani --- python/ray/air/tests/test_keras_callback.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/ray/air/tests/test_keras_callback.py b/python/ray/air/tests/test_keras_callback.py index 2708f3c15354..09fb5d572593 100644 --- a/python/ray/air/tests/test_keras_callback.py +++ b/python/ray/air/tests/test_keras_callback.py @@ -196,9 +196,7 @@ def test_keras_callback_e2e(): assert isinstance(checkpoint, TensorflowCheckpoint) assert checkpoint._flavor == TensorflowCheckpoint.Flavor.MODEL_WEIGHTS - predictor = TensorflowPredictor.from_checkpoint( - checkpoint, model=build_model - ) + predictor = TensorflowPredictor.from_checkpoint(checkpoint, model=build_model) items = np.random.uniform(0, 1, size=(10, 1)) predictor.predict(data=items) From 78d0d0d4fd186bc56a24dd7946c22f86cf9cc544 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 18:35:47 -0700 Subject: [PATCH 07/11] Fix tests Signed-off-by: Balaji Veeramani --- python/ray/air/tests/test_keras_callback.py | 4 +++- python/ray/train/tests/test_tensorflow_predictor.py | 4 ++-- python/ray/train/tests/test_tensorflow_trainer.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/ray/air/tests/test_keras_callback.py b/python/ray/air/tests/test_keras_callback.py index 09fb5d572593..3e4113ef40f2 100644 --- a/python/ray/air/tests/test_keras_callback.py +++ b/python/ray/air/tests/test_keras_callback.py @@ -196,7 +196,9 @@ def test_keras_callback_e2e(): assert isinstance(checkpoint, TensorflowCheckpoint) assert checkpoint._flavor == TensorflowCheckpoint.Flavor.MODEL_WEIGHTS - predictor = TensorflowPredictor.from_checkpoint(checkpoint, model=build_model) + predictor = TensorflowPredictor.from_checkpoint( + checkpoint, model_definition=build_model + ) items = np.random.uniform(0, 1, size=(10, 1)) predictor.predict(data=items) diff --git a/python/ray/train/tests/test_tensorflow_predictor.py b/python/ray/train/tests/test_tensorflow_predictor.py index 26a9085b18fb..33462d93980d 100644 --- a/python/ray/train/tests/test_tensorflow_predictor.py +++ b/python/ray/train/tests/test_tensorflow_predictor.py @@ -88,7 +88,7 @@ def test_init(): predictor = TensorflowPredictor(model=build_model(), preprocessor=preprocessor) checkpoint_predictor = TensorflowPredictor.from_checkpoint( - checkpoint, model=build_raw_model + checkpoint, model_definition=build_raw_model ) assert checkpoint_predictor._model.get_weights() == predictor._model.get_weights() @@ -241,7 +241,7 @@ def test_tensorflow_predictor_no_training(use_gpu): model = build_model() checkpoint = TensorflowCheckpoint.from_model(model) batch_predictor = BatchPredictor.from_checkpoint( - checkpoint, TensorflowPredictor, model=build_model, use_gpu=use_gpu + checkpoint, TensorflowPredictor, model_definition=build_model, use_gpu=use_gpu ) predict_dataset = ray.data.range(3) predictions = batch_predictor.predict(predict_dataset) diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index 06950866eae5..8fa41a2b805d 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -84,7 +84,7 @@ def train_func(): assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) batch_predictor = BatchPredictor.from_checkpoint( - result.checkpoint, TensorflowPredictor, model=build_model + result.checkpoint, TensorflowPredictor, model_definition=build_model ) predict_dataset = ray.data.range(3) From d583e485c02486d1eaee91c0376425ad7418e391 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 18:36:21 -0700 Subject: [PATCH 08/11] Update test_tensorflow_checkpoint.py Signed-off-by: Balaji Veeramani --- python/ray/train/tests/test_tensorflow_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/train/tests/test_tensorflow_checkpoint.py b/python/ray/train/tests/test_tensorflow_checkpoint.py index 63202a7dbe02..3bcf9a064b0a 100644 --- a/python/ray/train/tests/test_tensorflow_checkpoint.py +++ b/python/ray/train/tests/test_tensorflow_checkpoint.py @@ -41,7 +41,7 @@ def test_model_definition_raises_deprecation_warning(): model = get_model() checkpoint = TensorflowCheckpoint.from_model(model) with pytest.deprecated_call(): - checkpoint.get_model(model=get_model) + checkpoint.get_model(model_definition=get_model) def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool: From 20dcaa9ff2a19efea0605cb72d403a47070432de Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 18:38:40 -0700 Subject: [PATCH 09/11] Create temp.py Signed-off-by: Balaji Veeramani --- temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 temp.py diff --git a/temp.py b/temp.py new file mode 100644 index 000000000000..e69de29bb2d1 From 224aa51469793af42ded959b28789d68ab2832a4 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 18:38:45 -0700 Subject: [PATCH 10/11] Delete temp.py Signed-off-by: Balaji Veeramani --- temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 temp.py diff --git a/temp.py b/temp.py deleted file mode 100644 index e69de29bb2d1..000000000000 From fb372c87adb8adcd227aa1ec0dec1ea0a6ab078a Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Mar 2023 19:29:19 -0700 Subject: [PATCH 11/11] Update test_tensorflow_checkpoint.py Signed-off-by: Balaji Veeramani --- python/ray/train/tests/test_tensorflow_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/train/tests/test_tensorflow_checkpoint.py b/python/ray/train/tests/test_tensorflow_checkpoint.py index 3bcf9a064b0a..8514a0448d9d 100644 --- a/python/ray/train/tests/test_tensorflow_checkpoint.py +++ b/python/ray/train/tests/test_tensorflow_checkpoint.py @@ -40,7 +40,7 @@ def get_model(): def test_model_definition_raises_deprecation_warning(): model = get_model() checkpoint = TensorflowCheckpoint.from_model(model) - with pytest.deprecated_call(): + with pytest.raises(DeprecationWarning): checkpoint.get_model(model_definition=get_model)