From ba8f96a6e68ca7f63b55d72a21aad364334e4d8e Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Fri, 9 Feb 2024 12:06:51 -0800 Subject: [PATCH] Enable model custom dependency installation using virtual environment (#2910) * Explicitly set default dependency upgrade strategy to only-if-needed * Add support for venv creation per model * Create virtual env only when requirements need to be installed * Format logger and exception messages * Enable per model useVenv configuration option * Add integration tests and documentation * Fix integraiton test failures * Revert "Fix integraiton test failures" This reverts commit 6c4a279c1dd730acc7e797b4d4df993962acdbc8. * Update integration test teardown functionality * Refactor implementaiton to check for useVenv in Model.java * Fix dependencyPath logic * Refactor isUseVenv and integration tests * Update documentation --------- Co-authored-by: Ankith Gunapal --- .../serve/archive/model/ModelConfig.java | 22 +- .../serve/util/messages/EnvironmentUtils.java | 18 +- .../java/org/pytorch/serve/wlm/Model.java | 10 + .../org/pytorch/serve/wlm/ModelManager.java | 189 ++++++++--- .../org/pytorch/serve/ModelServerTest.java | 2 +- .../java/org/pytorch/serve/WorkflowTest.java | 2 +- model-archiver/README.md | 13 +- .../custom_dependencies/config.properties | 1 + .../mnist_custom_dependencies_handler.py | 20 ++ .../custom_dependencies/model_config.yaml | 1 + .../custom_dependencies/requirements.txt | 1 + test/pytest/test_model_custom_dependencies.py | 320 ++++++++++++++++++ 12 files changed, 543 insertions(+), 56 deletions(-) create mode 100644 test/pytest/test_data/custom_dependencies/config.properties create mode 100644 test/pytest/test_data/custom_dependencies/mnist_custom_dependencies_handler.py create mode 100644 test/pytest/test_data/custom_dependencies/model_config.yaml create mode 100644 test/pytest/test_data/custom_dependencies/requirements.txt create mode 100644 test/pytest/test_model_custom_dependencies.py diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index adcf42afd8..fd67561c0e 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -65,9 +65,14 @@ public class ModelConfig { private int maxSequenceJobQueueSize = 1; /** the max number of sequences can be accepted. The default value is 1. */ private int maxNumSequence = 1; - /** continuousBatching is a flag to enable continuous batching. */ private boolean continuousBatching; + /** + * Create python virtual environment when using python backend to install model dependencies (if + * enabled globally using configuration install_py_dep_per_model=true) and run workers for model + * loading and inference. + */ + private boolean useVenv; public static ModelConfig build(Map yamlMap) { ModelConfig modelConfig = new ModelConfig(); @@ -207,6 +212,13 @@ public static ModelConfig build(Map yamlMap) { v); } break; + case "useVenv": + if (v instanceof Boolean) { + modelConfig.setUseVenv((boolean) v); + } else { + logger.warn("Invalid useVenv: {}, should be true or false", v); + } + break; default: break; } @@ -379,6 +391,14 @@ public void setMaxNumSequence(int maxNumSequence) { this.maxNumSequence = Math.max(1, maxNumSequence); } + public boolean getUseVenv() { + return useVenv; + } + + public void setUseVenv(boolean useVenv) { + this.useVenv = useVenv; + } + public enum ParallelType { NONE(""), PP("pp"), diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java index 7b87b8e278..3b5a6779b2 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java @@ -3,6 +3,8 @@ import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @@ -74,15 +76,29 @@ public static String[] getEnvString(String cwd, String modelPath, String handler public static String getPythonRunTime(Model model) { String pythonRuntime; - Manifest.RuntimeType runtime = model.getModelArchive().getManifest().getRuntime(); + Manifest.RuntimeType runtime = model.getRuntimeType(); if (runtime == Manifest.RuntimeType.PYTHON) { pythonRuntime = configManager.getPythonExecutable(); + Path pythonVenvRuntime = + Paths.get(getPythonVenvPath(model).toString(), "bin", "python"); + if (model.isUseVenv() && Files.exists(pythonVenvRuntime)) { + pythonRuntime = pythonVenvRuntime.toString(); + } } else { pythonRuntime = runtime.getValue(); } return pythonRuntime; } + public static File getPythonVenvPath(Model model) { + File modelDir = model.getModelDir(); + if (Files.isSymbolicLink(modelDir.toPath())) { + modelDir = modelDir.getParentFile(); + } + Path venvPath = Paths.get(modelDir.getAbsolutePath(), "venv").toAbsolutePath(); + return venvPath.toFile(); + } + public static String[] getCppEnvString(String libPath) { ArrayList envList = new ArrayList<>(); StringBuilder cppPath = new StringBuilder(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index f5a09d1b46..2d14e89b25 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -83,11 +83,13 @@ public class Model { private boolean useJobTicket; private AtomicInteger numJobTickets; private boolean continuousBatching; + private boolean useVenv; public Model(ModelArchive modelArchive, int queueSize) { this.modelArchive = modelArchive; if (modelArchive != null && modelArchive.getModelConfig() != null) { continuousBatching = modelArchive.getModelConfig().isContinuousBatching(); + useVenv = modelArchive.getModelConfig().getUseVenv(); if (modelArchive.getModelConfig().getParallelLevel() > 0 && modelArchive.getModelConfig().getParallelType() != ModelConfig.ParallelType.NONE) { @@ -636,6 +638,14 @@ public boolean isContinuousBatching() { return continuousBatching; } + public boolean isUseVenv() { + if (getRuntimeType() == Manifest.RuntimeType.PYTHON) { + return useVenv; + } else { + return false; + } + } + public boolean hasTensorParallel() { switch (this.parallelType) { case PP: diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 5258e38f51..7e0c64e496 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -100,6 +100,8 @@ public void registerAndUpdateModel(String modelName, JsonObject modelInfo) createVersionedModel(tempModel, versionId); + setupModelVenv(tempModel); + setupModelDependencies(tempModel); if (defaultVersion) { modelManager.setDefaultVersion(modelName, versionId); @@ -153,6 +155,8 @@ public ModelArchive registerModel( } } + setupModelVenv(tempModel); + setupModelDependencies(tempModel); logger.info("Model {} loaded.", tempModel.getModelName()); @@ -205,23 +209,119 @@ private ModelArchive createModelArchive( return archive; } + private void setupModelVenv(Model model) + throws IOException, InterruptedException, ModelException { + if (!model.isUseVenv()) { + return; + } + + File venvPath = EnvironmentUtils.getPythonVenvPath(model); + List commandParts = new ArrayList<>(); + commandParts.add(configManager.getPythonExecutable()); + commandParts.add("-m"); + commandParts.add("venv"); + commandParts.add("--clear"); + commandParts.add("--system-site-packages"); + commandParts.add(venvPath.toString()); + + ProcessBuilder processBuilder = new ProcessBuilder(commandParts); + + if (isValidDependencyPath(venvPath)) { + processBuilder.directory(venvPath.getParentFile()); + } else { + throw new ModelException( + "Invalid python venv path for model " + + model.getModelName() + + ": " + + venvPath.toString()); + } + Map environment = processBuilder.environment(); + String[] envp = + EnvironmentUtils.getEnvString( + configManager.getModelServerHome(), + model.getModelDir().getAbsolutePath(), + null); + for (String envVar : envp) { + String[] parts = envVar.split("=", 2); + if (parts.length == 2) { + environment.put(parts[0], parts[1]); + } + } + processBuilder.redirectErrorStream(true); + + Process process = processBuilder.start(); + + int exitCode = process.waitFor(); + String line; + StringBuilder outputString = new StringBuilder(); + BufferedReader brdr = new BufferedReader(new InputStreamReader(process.getInputStream())); + while ((line = brdr.readLine()) != null) { + outputString.append(line + "\n"); + } + + if (exitCode == 0) { + logger.info( + "Created virtual environment for model {}: {}", + model.getModelName(), + venvPath.toString()); + } else { + logger.error( + "Virtual environment creation for model {} at {} failed:\n{}", + model.getModelName(), + venvPath.toString(), + outputString.toString()); + throw new ModelException( + "Virtual environment creation failed for model " + model.getModelName()); + } + } + private void setupModelDependencies(Model model) throws IOException, InterruptedException, ModelException { String requirementsFile = model.getModelArchive().getManifest().getModel().getRequirementsFile(); - if (configManager.getInstallPyDepPerModel() && requirementsFile != null) { - Path requirementsFilePath = - Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile); + if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) { + return; + } + + String pythonRuntime = EnvironmentUtils.getPythonRunTime(model); + Path requirementsFilePath = + Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); + List commandParts = new ArrayList<>(); + ProcessBuilder processBuilder = new ProcessBuilder(); + + if (model.isUseVenv()) { + if (!isValidDependencyPath(Paths.get(pythonRuntime).toFile())) { + throw new ModelException( + "Invalid python venv runtime path for model " + + model.getModelName() + + ": " + + pythonRuntime); + } - String pythonRuntime = EnvironmentUtils.getPythonRunTime(model); + processBuilder.directory(EnvironmentUtils.getPythonVenvPath(model).getParentFile()); + commandParts.add(pythonRuntime); + commandParts.add("-m"); + commandParts.add("pip"); + commandParts.add("install"); + commandParts.add("-U"); + commandParts.add("--upgrade-strategy"); + commandParts.add("only-if-needed"); + commandParts.add("-r"); + commandParts.add(requirementsFilePath.toString()); + } else { File dependencyPath = model.getModelDir(); if (Files.isSymbolicLink(dependencyPath.toPath())) { dependencyPath = dependencyPath.getParentFile(); } + dependencyPath = dependencyPath.getAbsoluteFile(); + if (!isValidDependencyPath(dependencyPath)) { + throw new ModelException( + "Invalid 3rd party package installation path " + dependencyPath.toString()); + } - List commandParts = new ArrayList<>(); + processBuilder.directory(dependencyPath); commandParts.add(pythonRuntime); commandParts.add("-m"); @@ -229,57 +329,48 @@ private void setupModelDependencies(Model model) commandParts.add("install"); commandParts.add("-U"); commandParts.add("-t"); - commandParts.add(dependencyPath.getAbsolutePath()); + commandParts.add(dependencyPath.toString()); commandParts.add("-r"); commandParts.add(requirementsFilePath.toString()); + } - String[] envp = - EnvironmentUtils.getEnvString( - configManager.getModelServerHome(), - model.getModelDir().getAbsolutePath(), - null); - - ProcessBuilder processBuilder = new ProcessBuilder(commandParts); - if (isValidDependencyPath(dependencyPath)) { - processBuilder.directory(dependencyPath); - } else { - throw new ModelException( - "Invalid 3rd party package installation path " - + dependencyPath.getCanonicalPath()); + processBuilder.command(commandParts); + String[] envp = + EnvironmentUtils.getEnvString( + configManager.getModelServerHome(), + model.getModelDir().getAbsolutePath(), + null); + Map environment = processBuilder.environment(); + for (String envVar : envp) { + String[] parts = envVar.split("=", 2); + if (parts.length == 2) { + environment.put(parts[0], parts[1]); } + } + processBuilder.redirectErrorStream(true); - Map environment = processBuilder.environment(); - for (String envVar : envp) { - String[] parts = envVar.split("=", 2); - if (parts.length == 2) { - environment.put(parts[0], parts[1]); - } - } - Process process = processBuilder.start(); - int exitCode = process.waitFor(); - - if (exitCode != 0) { - - String line; - StringBuilder outputString = new StringBuilder(); - // process's stdout is InputStream for caller process - BufferedReader brdr = - new BufferedReader(new InputStreamReader(process.getInputStream())); - while ((line = brdr.readLine()) != null) { - outputString.append(line); - } - StringBuilder errorString = new StringBuilder(); - // process's stderr is ErrorStream for caller process - brdr = new BufferedReader(new InputStreamReader(process.getErrorStream())); - while ((line = brdr.readLine()) != null) { - errorString.append(line); - } + Process process = processBuilder.start(); - logger.error("Dependency installation stderr:\n" + errorString.toString()); + int exitCode = process.waitFor(); + String line; + StringBuilder outputString = new StringBuilder(); + BufferedReader brdr = new BufferedReader(new InputStreamReader(process.getInputStream())); + while ((line = brdr.readLine()) != null) { + outputString.append(line + "\n"); + } - throw new ModelException( - "Custom pip package installation failed for " + model.getModelName()); - } + if (exitCode == 0) { + logger.info( + "Installed custom pip packages for model {}:\n{}", + model.getModelName(), + outputString.toString()); + } else { + logger.error( + "Custom pip package installation failed for model {}:\n{}", + model.getModelName(), + outputString.toString()); + throw new ModelException( + "Custom pip package installation failed for model " + model.getModelName()); } } diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index b28b8963bb..df67aa95df 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -1169,7 +1169,7 @@ public void testModelWithInvalidCustomPythonDependency() Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.BAD_REQUEST); Assert.assertEquals( resp.getMessage(), - "Custom pip package installation failed for custom_invalid_python_dep"); + "Custom pip package installation failed for model custom_invalid_python_dep"); TestUtils.setConfiguration(configManager, "install_py_dep_per_model", "false"); channel.close().sync(); } diff --git a/frontend/server/src/test/java/org/pytorch/serve/WorkflowTest.java b/frontend/server/src/test/java/org/pytorch/serve/WorkflowTest.java index 68a025065e..9c530d2f8b 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/WorkflowTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/WorkflowTest.java @@ -424,7 +424,7 @@ public void testWorkflowWithInvalidCustomPythonDependencyModel() Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INTERNAL_SERVER_ERROR); Assert.assertEquals( resp.getMessage(), - "Workflow custom_invalid_python_dep has failed to register. Failures: [Workflow Node custom_invalid_python_dep__custom_invalid_python_dep failed to register. Details: Custom pip package installation failed for custom_invalid_python_dep__custom_invalid_python_dep]"); + "Workflow custom_invalid_python_dep has failed to register. Failures: [Workflow Node custom_invalid_python_dep__custom_invalid_python_dep failed to register. Details: Custom pip package installation failed for model custom_invalid_python_dep__custom_invalid_python_dep]"); TestUtils.setConfiguration(configManager, "install_py_dep_per_model", "false"); channel.close().sync(); } diff --git a/model-archiver/README.md b/model-archiver/README.md index ae9df6f880..ab18ba3ba9 100644 --- a/model-archiver/README.md +++ b/model-archiver/README.md @@ -161,19 +161,26 @@ For more details refer [default handler documentation](../docs/default_handlers. ### Config file -A model config yaml file. For example: +A model config yaml file. For example: ``` # TS frontend parameters -# See all supported parameters: https://github.com/pytorch/serve/blob/master/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java#L14 +# See all supported parameters: https://github.com/pytorch/serve/blob/master/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java#L14 minWorkers: 1 # default: #CPU or #GPU maxWorkers: 1 # default: #CPU or #GPU batchSize: 1 # default: 1 maxBatchDelay: 100 # default: 100 msec responseTimeout: 120 # default: 120 sec deviceType: cpu # cpu, gpu, neuron -deviceIds: [0,1,2,3] # gpu device ids allocated to this model. +deviceIds: [0,1,2,3] # gpu device ids allocated to this model. parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel. Default: empty +useVenv: Create python virtual environment when using python backend to install model dependencies + (if enabled globally using install_py_dep_per_model=true) and run workers for model loading + and inference. Note that, although creation of virtual environment adds a latency overhead + (approx. 2 to 3 seconds) during model load and disk space overhead (approx. 25M), overall + it can speed up load time and reduce disk utilization for models with custom dependencies + since it enables reusing custom packages(specified in requirements.txt) and their + supported dependencies that are already available in the base python environment. # See torchrun parameters: https://pytorch.org/docs/stable/elastic/run.html torchrun: diff --git a/test/pytest/test_data/custom_dependencies/config.properties b/test/pytest/test_data/custom_dependencies/config.properties new file mode 100644 index 0000000000..8a0c798580 --- /dev/null +++ b/test/pytest/test_data/custom_dependencies/config.properties @@ -0,0 +1 @@ +install_py_dep_per_model=true diff --git a/test/pytest/test_data/custom_dependencies/mnist_custom_dependencies_handler.py b/test/pytest/test_data/custom_dependencies/mnist_custom_dependencies_handler.py new file mode 100644 index 0000000000..4c203e7b78 --- /dev/null +++ b/test/pytest/test_data/custom_dependencies/mnist_custom_dependencies_handler.py @@ -0,0 +1,20 @@ +# import custom dependency to test that it has been installed +import matplotlib.pyplot as pyplt +from torchvision import transforms + +from ts.torch_handler.image_classifier import ImageClassifier + + +class MNISTDigitClassifier(ImageClassifier): + image_processing = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + + def __init__(self): + super(MNISTDigitClassifier, self).__init__() + + def postprocess(self, data): + result = data.argmax(1).tolist() + # test that custom dependency works + pyplt.plot(result) + return result diff --git a/test/pytest/test_data/custom_dependencies/model_config.yaml b/test/pytest/test_data/custom_dependencies/model_config.yaml new file mode 100644 index 0000000000..14224bbdc9 --- /dev/null +++ b/test/pytest/test_data/custom_dependencies/model_config.yaml @@ -0,0 +1 @@ +useVenv: true diff --git a/test/pytest/test_data/custom_dependencies/requirements.txt b/test/pytest/test_data/custom_dependencies/requirements.txt new file mode 100644 index 0000000000..6ccafc3f90 --- /dev/null +++ b/test/pytest/test_data/custom_dependencies/requirements.txt @@ -0,0 +1 @@ +matplotlib diff --git a/test/pytest/test_model_custom_dependencies.py b/test/pytest/test_model_custom_dependencies.py new file mode 100644 index 0000000000..4a8fec5151 --- /dev/null +++ b/test/pytest/test_model_custom_dependencies.py @@ -0,0 +1,320 @@ +import os +import pathlib +import subprocess + +import requests +import test_utils +from model_archiver import ModelArchiver, ModelArchiverConfig +from model_archiver.manifest_components.manifest import RuntimeType + + +def setup_module(module): + test_utils.torchserve_cleanup() + # Clean out custom model dependencies in base python environment + subprocess.run( + [ + "pip", + "uninstall", + "-y", + "-r", + str( + os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "requirements.txt", + ) + ), + ], + check=True, + ) + # Create model store directory + pathlib.Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True) + + +def teardown_module(module): + test_utils.torchserve_cleanup() + # Restore custom model dependencies in base python environment + subprocess.run( + [ + "pip", + "install", + "-r", + str( + os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "requirements.txt", + ) + ), + ], + check=True, + ) + + +def generate_model_archive(use_requirements=False, use_venv=False): + config = ModelArchiverConfig( + model_name="mnist_custom_dependencies", + handler=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "mnist_custom_dependencies_handler.py", + ), + runtime=RuntimeType.PYTHON.value, + model_file=os.path.join( + test_utils.REPO_ROOT, "examples", "image_classifier", "mnist", "mnist.py" + ), + serialized_file=os.path.join( + test_utils.REPO_ROOT, + "examples", + "image_classifier", + "mnist", + "mnist_cnn.pt", + ), + extra_files=None, + export_path=test_utils.MODEL_STORE, + force=True, + archive_format="no-archive", + version="1.0", + requirements_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "requirements.txt", + ) + if use_requirements + else None, + config_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "model_config.yaml", + ) + if use_venv + else None, + ) + + ModelArchiver.generate_model_archive(config) + + +def register_model_and_make_inference_request(expect_model_load_failure=False): + try: + resp = test_utils.register_model( + "mnist_custom_dependencies", "mnist_custom_dependencies" + ) + resp.raise_for_status() + except Exception as e: + if expect_model_load_failure: + return + else: + raise e + + if expect_model_load_failure: + raise Exception("Expected model load failure but model load succeeded") + + data_file = os.path.join( + test_utils.REPO_ROOT, + "examples", + "image_classifier", + "mnist", + "test_data", + "0.png", + ) + with open(data_file, "rb") as input_data: + resp = requests.post( + url=f"http://localhost:8080/predictions/mnist_custom_dependencies", + data=input_data, + ) + resp.raise_for_status() + + +def test_install_dependencies_to_target_directory_with_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=True, use_venv=False) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "config.properties", + ), + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=False) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_install_dependencies_to_target_directory_without_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=False, use_venv=False) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "config.properties", + ), + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_disable_install_dependencies_to_target_directory_with_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=True, use_venv=False) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=None, + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_disable_install_dependencies_to_target_directory_without_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=False, use_venv=False) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=None, + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_install_dependencies_to_venv_with_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=True, use_venv=True) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "config.properties", + ), + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=False) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_install_dependencies_to_venv_without_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=False, use_venv=True) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=os.path.join( + test_utils.REPO_ROOT, + "test", + "pytest", + "test_data", + "custom_dependencies", + "config.properties", + ), + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_disable_install_dependencies_to_venv_with_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=True, use_venv=True) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=None, + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + +def test_disable_install_dependencies_to_venv_without_requirements(): + # Torchserve cleanup + test_utils.stop_torchserve() + test_utils.delete_all_snapshots() + + try: + generate_model_archive(use_requirements=False, use_venv=True) + test_utils.start_torchserve( + model_store=test_utils.MODEL_STORE, + snapshot_file=None, + no_config_snapshots=True, + gen_mar=False, + ) + register_model_and_make_inference_request(expect_model_load_failure=True) + finally: + test_utils.stop_torchserve() + test_utils.delete_all_snapshots()