diff --git a/model-archiver/model_archiver/__init__.py b/model-archiver/model_archiver/__init__.py index a459548609..ddac7739f1 100644 --- a/model-archiver/model_archiver/__init__.py +++ b/model-archiver/model_archiver/__init__.py @@ -1,5 +1,3 @@ - - """ This module does the following: Exports the model folder to generate a Model Archive file out of it in .mar format @@ -7,3 +5,6 @@ from . import version __version__ = version.__version__ + +from .model_archiver import ModelArchiver +from .model_archiver_config import ModelArchiverConfig diff --git a/model-archiver/model_archiver/arg_parser.py b/model-archiver/model_archiver/arg_parser.py index 17a54e6380..f431ea6402 100644 --- a/model-archiver/model_archiver/arg_parser.py +++ b/model-archiver/model_archiver/arg_parser.py @@ -6,6 +6,7 @@ import os from .manifest_components.manifest import RuntimeType +from .model_archiver_config import ModelArchiverConfig # noinspection PyTypeChecker @@ -154,4 +155,4 @@ def export_model_args_parser(): help="Path to a yaml file containing model configuration eg. batch_size.", ) - return parser_export + return ModelArchiverConfig.from_args(parser_export.parse_args()) diff --git a/model-archiver/model_archiver/model_archiver.py b/model-archiver/model_archiver/model_archiver.py new file mode 100644 index 0000000000..9dd72b0861 --- /dev/null +++ b/model-archiver/model_archiver/model_archiver.py @@ -0,0 +1,17 @@ +""" +Helper class to generate a model archive file +""" + +from model_archiver.model_archiver_config import ModelArchiverConfig +from model_archiver.model_packaging import generate_model_archive + + +class ModelArchiver: + @staticmethod + def generate_model_archive(config: ModelArchiverConfig) -> None: + """ + Generate a model archive file + :param config: Model Archiver Config object + :return: + """ + generate_model_archive(config) diff --git a/model-archiver/model_archiver/model_archiver_config.py b/model-archiver/model_archiver/model_archiver_config.py new file mode 100644 index 0000000000..5cbc4fbbc1 --- /dev/null +++ b/model-archiver/model_archiver/model_archiver_config.py @@ -0,0 +1,28 @@ +import os +from argparse import Namespace +from dataclasses import dataclass, fields +from typing import Literal, Optional + +from model_archiver.manifest_components.manifest import RuntimeType + + +@dataclass +class ModelArchiverConfig: + model_name: str + handler: str + version: str + serialized_file: Optional[str] = None + model_file: Optional[str] = None + extra_files: Optional[str] = None + runtime: str = RuntimeType.PYTHON.value + export_path: str = os.getcwd() + archive_format: Literal["default", "tgz", "no-archive"] = "default" + force: bool = False + requirements_file: Optional[str] = None + config_file: Optional[str] = None + + @classmethod + def from_args(cls, args: Namespace) -> "ModelArchiverConfig": + params = {field.name: getattr(args, field.name) for field in fields(cls)} + config = cls(**params) + return config diff --git a/model-archiver/model_archiver/model_packaging.py b/model-archiver/model_archiver/model_packaging.py index 3304f6a4f1..eeccea885c 100644 --- a/model-archiver/model_archiver/model_packaging.py +++ b/model-archiver/model_archiver/model_packaging.py @@ -4,31 +4,32 @@ import logging import shutil -import sys +from typing import Optional from model_archiver.arg_parser import ArgParser +from model_archiver.model_archiver_config import ModelArchiverConfig from model_archiver.model_archiver_error import ModelArchiverError from model_archiver.model_packaging_utils import ModelExportUtils -def package_model(args, manifest): +def package_model(config: ModelArchiverConfig, manifest: str): """ Internal helper for the exporting model command line interface. """ - model_file = args.model_file - serialized_file = args.serialized_file - model_name = args.model_name - handler = args.handler - extra_files = args.extra_files - export_file_path = args.export_path - requirements_file = args.requirements_file - config_file = args.config_file + model_file = config.model_file + serialized_file = config.serialized_file + model_name = config.model_name + handler = config.handler + extra_files = config.extra_files + export_file_path = config.export_path + requirements_file = config.requirements_file + config_file = config.config_file try: ModelExportUtils.validate_inputs(model_name, export_file_path) # Step 1 : Check if .mar already exists with the given model name export_file_path = ModelExportUtils.check_mar_already_exists( - model_name, export_file_path, args.force, args.archive_format + model_name, export_file_path, config.force, config.archive_format ) # Step 2 : Copy all artifacts to temp directory @@ -45,7 +46,7 @@ def package_model(args, manifest): # Step 2 : Zip 'em all up ModelExportUtils.archive( - export_file_path, model_name, model_path, manifest, args.archive_format + export_file_path, model_name, model_path, manifest, config.archive_format ) shutil.rmtree(model_path) logging.info( @@ -53,19 +54,20 @@ def package_model(args, manifest): ) except ModelArchiverError as e: logging.error(e) - sys.exit(1) + raise e -def generate_model_archive(): +def generate_model_archive(config: Optional[ModelArchiverConfig] = None): """ Generate a model archive file :return: """ logging.basicConfig(format="%(levelname)s - %(message)s") - args = ArgParser.export_model_args_parser().parse_args() - manifest = ModelExportUtils.generate_manifest_json(args) - package_model(args, manifest=manifest) + if config is None: + config = ArgParser.export_model_args_parser() + manifest = ModelExportUtils.generate_manifest_json(config) + package_model(config, manifest=manifest) if __name__ == "__main__": diff --git a/model-archiver/model_archiver/model_packaging_utils.py b/model-archiver/model_archiver/model_packaging_utils.py index c46e30888e..01ac568a95 100644 --- a/model-archiver/model_archiver/model_packaging_utils.py +++ b/model-archiver/model_archiver/model_packaging_utils.py @@ -15,6 +15,7 @@ from .manifest_components.manifest import Manifest from .manifest_components.model import Model +from .model_archiver_config import ModelArchiverConfig from .model_archiver_error import ModelArchiverError archiving_options = { @@ -107,29 +108,29 @@ def find_unique(files, suffix): ) @staticmethod - def generate_model(modelargs): + def generate_model(modelcfg: ModelArchiverConfig): model = Model( - model_name=modelargs.model_name, - serialized_file=modelargs.serialized_file, - model_file=modelargs.model_file, - handler=modelargs.handler, - model_version=modelargs.version, - requirements_file=modelargs.requirements_file, - config_file=modelargs.config_file, + model_name=modelcfg.model_name, + serialized_file=modelcfg.serialized_file, + model_file=modelcfg.model_file, + handler=modelcfg.handler, + model_version=modelcfg.version, + requirements_file=modelcfg.requirements_file, + config_file=modelcfg.config_file, ) return model @staticmethod - def generate_manifest_json(args): + def generate_manifest_json(config: ModelArchiverConfig) -> str: """ Function to generate manifest as a json string from the inputs provided by the user in the command line :param args: :return: """ - model = ModelExportUtils.generate_model(args) + model = ModelExportUtils.generate_model(config) - manifest = Manifest(runtime=args.runtime, model=model) + manifest = Manifest(runtime=config.runtime, model=model) return str(manifest) diff --git a/model-archiver/model_archiver/tests/integ_tests/test_integration_model_archiver.py b/model-archiver/model_archiver/tests/integ_tests/test_integration_model_archiver.py index 089bed7b66..9bd07b5624 100644 --- a/model-archiver/model_archiver/tests/integ_tests/test_integration_model_archiver.py +++ b/model-archiver/model_archiver/tests/integ_tests/test_integration_model_archiver.py @@ -40,12 +40,11 @@ def delete_file_path(path): pass -def run_test(test, args, mocker): - m = mocker.patch( +def run_test(test, config, mocker): + mocker.patch( "model_archiver.model_packaging.ArgParser.export_model_args_parser", + return_value=config, ) - m.return_value.parse_args.return_value = args - mocker.patch("sys.exit", side_effect=Exception()) from model_archiver.model_packaging import generate_model_archive it = test.get("iterations", 1) @@ -179,7 +178,9 @@ def build_namespace(test): args = Namespace(**{k.replace("-", "_"): test[k] for k in keys}) - return args + config = model_archiver.ModelArchiverConfig.from_args(args) + + return config def make_paths_absolute(test, keys): diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_archiver.py b/model-archiver/model_archiver/tests/unit_tests/test_model_archiver.py new file mode 100644 index 0000000000..45d9e2984c --- /dev/null +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_archiver.py @@ -0,0 +1,68 @@ +from argparse import Namespace +from collections import namedtuple + +import pytest +from model_archiver import ModelArchiver, ModelArchiverConfig +from model_archiver.manifest_components.manifest import RuntimeType + + +# noinspection PyClassHasNoInit +class TestModelArchiver: + model_name = "my-model" + model_file = "my-model/" + serialized_file = "my-model/" + handler = "a.py::my-awesome-func" + export_path = "/Users/dummyUser/" + version = "1.0" + requirements_file = "requirements.txt" + config_file = None + + config = ModelArchiverConfig( + model_name=model_name, + handler=handler, + runtime=RuntimeType.PYTHON.value, + model_file=model_file, + serialized_file=serialized_file, + extra_files=None, + export_path=export_path, + force=False, + archive_format="default", + version=version, + requirements_file=requirements_file, + config_file=None, + ) + + @pytest.fixture() + def patches(self, mocker): + Patches = namedtuple("Patches", ["arg_parse", "export_utils", "export_method"]) + patches = Patches( + mocker.patch("model_archiver.arg_parser.ArgParser"), + mocker.patch("model_archiver.model_packaging.ModelExportUtils"), + mocker.patch("model_archiver.model_packaging.package_model"), + ) + mocker.patch("shutil.rmtree") + + return patches + + def test_gen_model_archive(self, patches): + ModelArchiver.generate_model_archive(self.config) + patches.export_method.assert_called() + + def test_model_archiver_config_from_args(self): + args = Namespace( + model_name=self.model_name, + handler=self.handler, + runtime=RuntimeType.PYTHON.value, + model_file=self.model_file, + serialized_file=self.serialized_file, + extra_files=None, + export_path=self.export_path, + force=False, + archive_format="default", + version=self.version, + requirements_file=self.requirements_file, + config_file=None, + ) + config = ModelArchiverConfig.from_args(args) + + assert config == self.config diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py index da49fb1fc1..e628f29e25 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py @@ -1,6 +1,7 @@ from collections import namedtuple import pytest +from model_archiver import ModelArchiverConfig from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_packaging import generate_model_archive, package_model from model_archiver.model_packaging_utils import ModelExportUtils @@ -8,13 +9,6 @@ # noinspection PyClassHasNoInit class TestModelPackaging: - class Namespace: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - def update(self, **kwargs): - self.__dict__.update(kwargs) - model_name = "my-model" model_file = "my-model/" serialized_file = "my-model/" @@ -23,9 +17,8 @@ def update(self, **kwargs): version = "1.0" requirements_file = "requirements.txt" config_file = None - source_vocab = None - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value, @@ -35,9 +28,7 @@ def update(self, **kwargs): export_path=export_path, force=False, archive_format="default", - convert=False, version=version, - source_vocab=source_vocab, requirements_file=requirements_file, config_file=None, ) @@ -55,7 +46,7 @@ def patches(self, mocker): return patches def test_gen_model_archive(self, patches): - patches.arg_parse.export_model_args_parser.parse_args.return_value = self.args + patches.arg_parse.export_model_args_parser.parse_args.return_value = self.config generate_model_archive() patches.export_method.assert_called() @@ -67,25 +58,24 @@ def test_export_model_method(self, patches): ) patches.export_utils.zip.return_value = None - package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) + package_model(self.config, ModelExportUtils.generate_manifest_json(self.config)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() def test_export_model_method_tar(self, patches): - self.args.update(archive_format="tar") + self.config.archive_format = "tgz" patches.export_utils.check_mar_already_exists.return_value = "/Users/dummyUser/" patches.export_utils.check_custom_model_types.return_value = ( "/Users/dummyUser", ["a.txt", "b.txt"], ) - patches.export_utils.zip.return_value = None - package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) + package_model(self.config, ModelExportUtils.generate_manifest_json(self.config)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() def test_export_model_method_noarchive(self, patches): - self.args.update(archive_format="no-archive") + self.config.archive_format = "no-archive" patches.export_utils.check_mar_already_exists.return_value = "/Users/dummyUser/" patches.export_utils.check_custom_model_types.return_value = ( "/Users/dummyUser", @@ -93,6 +83,6 @@ def test_export_model_method_noarchive(self, patches): ) patches.export_utils.zip.return_value = None - package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) + package_model(self.config, ModelExportUtils.generate_manifest_json(self.config)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py index cc453be94a..b020ac3a2c 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest +from model_archiver import ModelArchiverConfig from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_archiver_error import ModelArchiverError from model_archiver.model_packaging_utils import ModelExportUtils @@ -142,10 +143,6 @@ def test_clean_call(self, patches): # noinspection PyClassHasNoInit class TestGenerateManifestProps: - class Namespace: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - model_name = "my-model" handler = "a.py::my-awesome-func" serialized_file = "model.pt" @@ -153,7 +150,7 @@ def __init__(self, **kwargs): version = "1.0" requirements_file = "requirements.txt" - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value, @@ -165,12 +162,12 @@ def __init__(self, **kwargs): ) def test_model(self): - mod = ModelExportUtils.generate_model(self.args) + mod = ModelExportUtils.generate_model(self.config) assert mod.model_name == self.model_name assert mod.handler == self.handler def test_manifest_json(self): - manifest = ModelExportUtils.generate_manifest_json(self.args) + manifest = ModelExportUtils.generate_manifest_json(self.config) manifest_json = json.loads(manifest) assert manifest_json["runtime"] == RuntimeType.PYTHON.value assert "model" in manifest_json diff --git a/test/pytest/test_auto_recover.py b/test/pytest/test_auto_recover.py index 87bca76c4a..15bf726f39 100644 --- a/test/pytest/test_auto_recover.py +++ b/test/pytest/test_auto_recover.py @@ -1,13 +1,13 @@ import json import platform import shutil -from argparse import Namespace from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest import requests import test_utils +from model_archiver import ModelArchiverConfig CURR_FILE_PATH = Path(__file__).parent REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent @@ -97,7 +97,7 @@ def create_mar_file(work_dir, model_archiver, model_name): handler_py_file = work_dir / "handler.py" handler_py_file.write_text(HANDLER_PY) - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=None, @@ -112,9 +112,7 @@ def create_mar_file(work_dir, model_archiver, model_name): config_file=None, ) - mock = MagicMock() - mock.parse_args = MagicMock(return_value=args) - with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): model_archiver.generate_model_archive() assert mar_file_path.exists() diff --git a/test/pytest/test_continuous_batching.py b/test/pytest/test_continuous_batching.py index 2d6974510f..fd926b2969 100644 --- a/test/pytest/test_continuous_batching.py +++ b/test/pytest/test_continuous_batching.py @@ -1,15 +1,15 @@ import json import shutil -from argparse import Namespace from pathlib import Path from queue import Empty -from unittest.mock import MagicMock, patch +from unittest.mock import patch from zipfile import ZIP_STORED, ZipFile import pytest import requests import test_utils import torch +from model_archiver import ModelArchiverConfig from test_data.streaming.stream_handler import StreamingHandler from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext @@ -31,7 +31,7 @@ def work_dir(tmp_path_factory, model_name): def create_mar_file(work_dir, model_archiver, model_name): mar_file_path = Path(work_dir).joinpath(model_name + ".mar") - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", model_file=CURR_FILE_PATH.joinpath( @@ -52,9 +52,7 @@ def create_mar_file(work_dir, model_archiver, model_name): extra_files=None, ) - mock = MagicMock() - mock.parse_args = MagicMock(return_value=args) - with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs with patch( "model_archiver.model_packaging_utils.zipfile.ZipFile", diff --git a/test/pytest/test_example_micro_batching.py b/test/pytest/test_example_micro_batching.py index 60a3c9c3d0..eb46ac8d22 100644 --- a/test/pytest/test_example_micro_batching.py +++ b/test/pytest/test_example_micro_batching.py @@ -2,15 +2,15 @@ import json import random import shutil -from argparse import Namespace from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch from zipfile import ZIP_STORED, ZipFile import pytest import requests import test_utils import yaml +from model_archiver import ModelArchiverConfig from torchvision.models.resnet import ResNet18_Weights from ts.torch_handler.unit_tests.test_utils.model_dir import download_model @@ -113,7 +113,7 @@ def create_mar_file(work_dir, serialized_file, model_archiver, model_name, reque extra_files = [name_file] - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=str(serialized_file), @@ -132,9 +132,7 @@ def create_mar_file(work_dir, serialized_file, model_archiver, model_name, reque config_file=config_file, ) - mock = MagicMock() - mock.parse_args = MagicMock(return_value=args) - with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs with patch( "model_archiver.model_packaging_utils.zipfile.ZipFile", diff --git a/test/pytest/test_example_near_real_time_video.py b/test/pytest/test_example_near_real_time_video.py index 18548a7e0b..549362ac36 100644 --- a/test/pytest/test_example_near_real_time_video.py +++ b/test/pytest/test_example_near_real_time_video.py @@ -4,13 +4,13 @@ import json import os import shutil -from argparse import Namespace from pathlib import Path import pytest import requests import test_utils import torch +from model_archiver import ModelArchiverConfig from ts.torch_handler.image_classifier import ImageClassifier from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext @@ -50,7 +50,7 @@ def create_mar_file(work_dir, session_mocker, model_archiver): mar_file_path = Path(work_dir).joinpath(model_name + ".mar") - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=os.path.join(REPO_ROOT_DIR, MODEL_PTH_FILE), @@ -67,10 +67,8 @@ def create_mar_file(work_dir, session_mocker, model_archiver): config_file=None, ) - mock = session_mocker.MagicMock() - mock.parse_args = session_mocker.MagicMock(return_value=args) session_mocker.patch( - "archiver.ArgParser.export_model_args_parser", return_value=mock + "archiver.ArgParser.export_model_args_parser", return_value=config ) # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs diff --git a/test/pytest/test_example_scriptable_tokenzier.py b/test/pytest/test_example_scriptable_tokenzier.py index 8c1d617270..05597b7c96 100644 --- a/test/pytest/test_example_scriptable_tokenzier.py +++ b/test/pytest/test_example_scriptable_tokenzier.py @@ -11,6 +11,7 @@ import requests import test_utils import torch +from model_archiver import ModelArchiverConfig from test_utils import REPO_ROOT from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext @@ -148,7 +149,7 @@ def create_mar_file(work_dir, session_mocker, jit_file_path, model_archiver): mar_file_path = os.path.join(work_dir, model_name + ".mar") - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=jit_file_path, @@ -163,10 +164,8 @@ def create_mar_file(work_dir, session_mocker, jit_file_path, model_archiver): config_file=None, ) - mock = session_mocker.MagicMock() - mock.parse_args = session_mocker.MagicMock(return_value=args) session_mocker.patch( - "archiver.ArgParser.export_model_args_parser", return_value=mock + "archiver.ArgParser.export_model_args_parser", return_value=config ) # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs diff --git a/test/pytest/test_example_torchrec_dlrm.py b/test/pytest/test_example_torchrec_dlrm.py index e4fef7e240..2a5b203ae2 100644 --- a/test/pytest/test_example_torchrec_dlrm.py +++ b/test/pytest/test_example_torchrec_dlrm.py @@ -4,13 +4,13 @@ import json import shutil import sys -from argparse import Namespace from pathlib import Path import pytest import requests import test_utils import torch +from model_archiver import ModelArchiverConfig from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext @@ -90,7 +90,7 @@ def create_mar_file(work_dir, session_mocker, serialized_file, model_archiver): mar_file_path = Path(work_dir).joinpath(model_name + ".mar") - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=str(serialized_file), @@ -107,10 +107,8 @@ def create_mar_file(work_dir, session_mocker, serialized_file, model_archiver): config_file=None, ) - mock = session_mocker.MagicMock() - mock.parse_args = session_mocker.MagicMock(return_value=args) session_mocker.patch( - "archiver.ArgParser.export_model_args_parser", return_value=mock + "archiver.ArgParser.export_model_args_parser", return_value=config ) # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs diff --git a/test/pytest/test_parallelism.py b/test/pytest/test_parallelism.py index 04183ec01f..58c2a0ddaf 100644 --- a/test/pytest/test_parallelism.py +++ b/test/pytest/test_parallelism.py @@ -1,13 +1,13 @@ import json import platform import shutil -from argparse import Namespace from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest import requests import test_utils +from model_archiver import ModelArchiverConfig CURR_FILE_PATH = Path(__file__).parent REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent @@ -77,7 +77,7 @@ def create_mar_file(work_dir, model_archiver, model_name): handler_py_file = work_dir / "handler.py" handler_py_file.write_text(HANDLER_PY) - args = Namespace( + config = ModelArchiverConfig( model_name=model_name, version="1.0", serialized_file=None, @@ -92,9 +92,7 @@ def create_mar_file(work_dir, model_archiver, model_name): config_file=model_config_yaml_file.as_posix(), ) - mock = MagicMock() - mock.parse_args = MagicMock(return_value=args) - with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): model_archiver.generate_model_archive() assert mar_file_path.exists()