From a5e85a950c2fab5729c46e7362a60765caa4b999 Mon Sep 17 00:00:00 2001 From: Justin Beavers Date: Tue, 15 Oct 2024 17:47:16 -0600 Subject: [PATCH] Fix training artifacts for 2GB+ models and `MSELoss` (#22414) --- orttraining/orttraining/python/training/onnxblock/blocks.py | 4 +++- .../test/python/orttraining_test_ort_apis_onnxblock.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index ed68171cc6f9c..c13843f816f16 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -54,8 +54,10 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: + # `save` will destructively access any external data + copied_model = copy.deepcopy(accessor._GLOBAL_ACCESSOR.model) onnx.save( - accessor._GLOBAL_ACCESSOR.model, + copied_model, self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 5c63be92d2b2f..0866d4a411e29 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1159,7 +1159,8 @@ def test_generate_artifacts_external_data_one_file(): assert os.path.exists(os.path.join(temp_dir, "checkpoint")) -def test_generate_artifacts_external_data_separate_files(): +@pytest.mark.parametrize("loss", [loss_t for loss_t in artifacts.LossType]) +def test_generate_artifacts_external_data_separate_files(loss): with tempfile.TemporaryDirectory() as temp_dir: _, simple_net = _get_models("cpu", 32, 28, 10, 10) @@ -1176,7 +1177,7 @@ def test_generate_artifacts_external_data_separate_files(): artifacts.generate_artifacts( os.path.join(temp_dir, "simple_net.onnx"), requires_grad=requires_grad_params, - loss=artifacts.LossType.CrossEntropyLoss, + loss=loss, optimizer=artifacts.OptimType.AdamW, artifact_directory=temp_dir, )