diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index ed68171cc6f9c..c13843f816f16 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -54,8 +54,10 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: + # `save` will destructively access any external data + copied_model = copy.deepcopy(accessor._GLOBAL_ACCESSOR.model) onnx.save( - accessor._GLOBAL_ACCESSOR.model, + copied_model, self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 5c63be92d2b2f..0866d4a411e29 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1159,7 +1159,8 @@ def test_generate_artifacts_external_data_one_file(): assert os.path.exists(os.path.join(temp_dir, "checkpoint")) -def test_generate_artifacts_external_data_separate_files(): +@pytest.mark.parametrize("loss", [loss_t for loss_t in artifacts.LossType]) +def test_generate_artifacts_external_data_separate_files(loss): with tempfile.TemporaryDirectory() as temp_dir: _, simple_net = _get_models("cpu", 32, 28, 10, 10) @@ -1176,7 +1177,7 @@ def test_generate_artifacts_external_data_separate_files(): artifacts.generate_artifacts( os.path.join(temp_dir, "simple_net.onnx"), requires_grad=requires_grad_params, - loss=artifacts.LossType.CrossEntropyLoss, + loss=loss, optimizer=artifacts.OptimType.AdamW, artifact_directory=temp_dir, )