Skip to content

Commit

Permalink
Enable model custom dependency installation using virtual environment (
Browse files Browse the repository at this point in the history
…#2910)

* Explicitly set default dependency upgrade strategy to only-if-needed

* Add support for venv creation per model

* Create virtual env only when requirements need to be installed

* Format logger and exception messages

* Enable per model useVenv configuration option

* Add integration tests and documentation

* Fix integraiton test failures

* Revert "Fix integraiton test failures"

This reverts commit 6c4a279.

* Update integration test teardown functionality

* Refactor implementaiton to check for useVenv in Model.java

* Fix dependencyPath logic

* Refactor isUseVenv and integration tests

* Update documentation

---------

Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
  • Loading branch information
namannandan and agunapal authored Feb 9, 2024
1 parent ddeb027 commit ba8f96a
Show file tree
Hide file tree
Showing 12 changed files with 543 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@ public class ModelConfig {
private int maxSequenceJobQueueSize = 1;
/** the max number of sequences can be accepted. The default value is 1. */
private int maxNumSequence = 1;

/** continuousBatching is a flag to enable continuous batching. */
private boolean continuousBatching;
/**
* Create python virtual environment when using python backend to install model dependencies (if
* enabled globally using configuration install_py_dep_per_model=true) and run workers for model
* loading and inference.
*/
private boolean useVenv;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
Expand Down Expand Up @@ -207,6 +212,13 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
v);
}
break;
case "useVenv":
if (v instanceof Boolean) {
modelConfig.setUseVenv((boolean) v);
} else {
logger.warn("Invalid useVenv: {}, should be true or false", v);
}
break;
default:
break;
}
Expand Down Expand Up @@ -379,6 +391,14 @@ public void setMaxNumSequence(int maxNumSequence) {
this.maxNumSequence = Math.max(1, maxNumSequence);
}

public boolean getUseVenv() {
return useVenv;
}

public void setUseVenv(boolean useVenv) {
this.useVenv = useVenv;
}

public enum ParallelType {
NONE(""),
PP("pp"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -74,15 +76,29 @@ public static String[] getEnvString(String cwd, String modelPath, String handler

public static String getPythonRunTime(Model model) {
String pythonRuntime;
Manifest.RuntimeType runtime = model.getModelArchive().getManifest().getRuntime();
Manifest.RuntimeType runtime = model.getRuntimeType();
if (runtime == Manifest.RuntimeType.PYTHON) {
pythonRuntime = configManager.getPythonExecutable();
Path pythonVenvRuntime =
Paths.get(getPythonVenvPath(model).toString(), "bin", "python");
if (model.isUseVenv() && Files.exists(pythonVenvRuntime)) {
pythonRuntime = pythonVenvRuntime.toString();
}
} else {
pythonRuntime = runtime.getValue();
}
return pythonRuntime;
}

public static File getPythonVenvPath(Model model) {
File modelDir = model.getModelDir();
if (Files.isSymbolicLink(modelDir.toPath())) {
modelDir = modelDir.getParentFile();
}
Path venvPath = Paths.get(modelDir.getAbsolutePath(), "venv").toAbsolutePath();
return venvPath.toFile();
}

public static String[] getCppEnvString(String libPath) {
ArrayList<String> envList = new ArrayList<>();
StringBuilder cppPath = new StringBuilder();
Expand Down
10 changes: 10 additions & 0 deletions frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ public class Model {
private boolean useJobTicket;
private AtomicInteger numJobTickets;
private boolean continuousBatching;
private boolean useVenv;

public Model(ModelArchive modelArchive, int queueSize) {
this.modelArchive = modelArchive;
if (modelArchive != null && modelArchive.getModelConfig() != null) {
continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
useVenv = modelArchive.getModelConfig().getUseVenv();
if (modelArchive.getModelConfig().getParallelLevel() > 0
&& modelArchive.getModelConfig().getParallelType()
!= ModelConfig.ParallelType.NONE) {
Expand Down Expand Up @@ -636,6 +638,14 @@ public boolean isContinuousBatching() {
return continuousBatching;
}

public boolean isUseVenv() {
if (getRuntimeType() == Manifest.RuntimeType.PYTHON) {
return useVenv;
} else {
return false;
}
}

public boolean hasTensorParallel() {
switch (this.parallelType) {
case PP:
Expand Down
189 changes: 140 additions & 49 deletions frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ public void registerAndUpdateModel(String modelName, JsonObject modelInfo)

createVersionedModel(tempModel, versionId);

setupModelVenv(tempModel);

setupModelDependencies(tempModel);
if (defaultVersion) {
modelManager.setDefaultVersion(modelName, versionId);
Expand Down Expand Up @@ -153,6 +155,8 @@ public ModelArchive registerModel(
}
}

setupModelVenv(tempModel);

setupModelDependencies(tempModel);

logger.info("Model {} loaded.", tempModel.getModelName());
Expand Down Expand Up @@ -205,81 +209,168 @@ private ModelArchive createModelArchive(
return archive;
}

private void setupModelVenv(Model model)
throws IOException, InterruptedException, ModelException {
if (!model.isUseVenv()) {
return;
}

File venvPath = EnvironmentUtils.getPythonVenvPath(model);
List<String> commandParts = new ArrayList<>();
commandParts.add(configManager.getPythonExecutable());
commandParts.add("-m");
commandParts.add("venv");
commandParts.add("--clear");
commandParts.add("--system-site-packages");
commandParts.add(venvPath.toString());

ProcessBuilder processBuilder = new ProcessBuilder(commandParts);

if (isValidDependencyPath(venvPath)) {
processBuilder.directory(venvPath.getParentFile());
} else {
throw new ModelException(
"Invalid python venv path for model "
+ model.getModelName()
+ ": "
+ venvPath.toString());
}
Map<String, String> environment = processBuilder.environment();
String[] envp =
EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
if (parts.length == 2) {
environment.put(parts[0], parts[1]);
}
}
processBuilder.redirectErrorStream(true);

Process process = processBuilder.start();

int exitCode = process.waitFor();
String line;
StringBuilder outputString = new StringBuilder();
BufferedReader brdr = new BufferedReader(new InputStreamReader(process.getInputStream()));
while ((line = brdr.readLine()) != null) {
outputString.append(line + "\n");
}

if (exitCode == 0) {
logger.info(
"Created virtual environment for model {}: {}",
model.getModelName(),
venvPath.toString());
} else {
logger.error(
"Virtual environment creation for model {} at {} failed:\n{}",
model.getModelName(),
venvPath.toString(),
outputString.toString());
throw new ModelException(
"Virtual environment creation failed for model " + model.getModelName());
}
}

private void setupModelDependencies(Model model)
throws IOException, InterruptedException, ModelException {
String requirementsFile =
model.getModelArchive().getManifest().getModel().getRequirementsFile();

if (configManager.getInstallPyDepPerModel() && requirementsFile != null) {
Path requirementsFilePath =
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile);
if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) {
return;
}

String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
Path requirementsFilePath =
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath();
List<String> commandParts = new ArrayList<>();
ProcessBuilder processBuilder = new ProcessBuilder();

if (model.isUseVenv()) {
if (!isValidDependencyPath(Paths.get(pythonRuntime).toFile())) {
throw new ModelException(
"Invalid python venv runtime path for model "
+ model.getModelName()
+ ": "
+ pythonRuntime);
}

String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
processBuilder.directory(EnvironmentUtils.getPythonVenvPath(model).getParentFile());

commandParts.add(pythonRuntime);
commandParts.add("-m");
commandParts.add("pip");
commandParts.add("install");
commandParts.add("-U");
commandParts.add("--upgrade-strategy");
commandParts.add("only-if-needed");
commandParts.add("-r");
commandParts.add(requirementsFilePath.toString());
} else {
File dependencyPath = model.getModelDir();
if (Files.isSymbolicLink(dependencyPath.toPath())) {
dependencyPath = dependencyPath.getParentFile();
}
dependencyPath = dependencyPath.getAbsoluteFile();
if (!isValidDependencyPath(dependencyPath)) {
throw new ModelException(
"Invalid 3rd party package installation path " + dependencyPath.toString());
}

List<String> commandParts = new ArrayList<>();
processBuilder.directory(dependencyPath);

commandParts.add(pythonRuntime);
commandParts.add("-m");
commandParts.add("pip");
commandParts.add("install");
commandParts.add("-U");
commandParts.add("-t");
commandParts.add(dependencyPath.getAbsolutePath());
commandParts.add(dependencyPath.toString());
commandParts.add("-r");
commandParts.add(requirementsFilePath.toString());
}

String[] envp =
EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);

ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
if (isValidDependencyPath(dependencyPath)) {
processBuilder.directory(dependencyPath);
} else {
throw new ModelException(
"Invalid 3rd party package installation path "
+ dependencyPath.getCanonicalPath());
processBuilder.command(commandParts);
String[] envp =
EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
Map<String, String> environment = processBuilder.environment();
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
if (parts.length == 2) {
environment.put(parts[0], parts[1]);
}
}
processBuilder.redirectErrorStream(true);

Map<String, String> environment = processBuilder.environment();
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
if (parts.length == 2) {
environment.put(parts[0], parts[1]);
}
}
Process process = processBuilder.start();
int exitCode = process.waitFor();

if (exitCode != 0) {

String line;
StringBuilder outputString = new StringBuilder();
// process's stdout is InputStream for caller process
BufferedReader brdr =
new BufferedReader(new InputStreamReader(process.getInputStream()));
while ((line = brdr.readLine()) != null) {
outputString.append(line);
}
StringBuilder errorString = new StringBuilder();
// process's stderr is ErrorStream for caller process
brdr = new BufferedReader(new InputStreamReader(process.getErrorStream()));
while ((line = brdr.readLine()) != null) {
errorString.append(line);
}
Process process = processBuilder.start();

logger.error("Dependency installation stderr:\n" + errorString.toString());
int exitCode = process.waitFor();
String line;
StringBuilder outputString = new StringBuilder();
BufferedReader brdr = new BufferedReader(new InputStreamReader(process.getInputStream()));
while ((line = brdr.readLine()) != null) {
outputString.append(line + "\n");
}

throw new ModelException(
"Custom pip package installation failed for " + model.getModelName());
}
if (exitCode == 0) {
logger.info(
"Installed custom pip packages for model {}:\n{}",
model.getModelName(),
outputString.toString());
} else {
logger.error(
"Custom pip package installation failed for model {}:\n{}",
model.getModelName(),
outputString.toString());
throw new ModelException(
"Custom pip package installation failed for model " + model.getModelName());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ public void testModelWithInvalidCustomPythonDependency()
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.BAD_REQUEST);
Assert.assertEquals(
resp.getMessage(),
"Custom pip package installation failed for custom_invalid_python_dep");
"Custom pip package installation failed for model custom_invalid_python_dep");
TestUtils.setConfiguration(configManager, "install_py_dep_per_model", "false");
channel.close().sync();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ public void testWorkflowWithInvalidCustomPythonDependencyModel()
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INTERNAL_SERVER_ERROR);
Assert.assertEquals(
resp.getMessage(),
"Workflow custom_invalid_python_dep has failed to register. Failures: [Workflow Node custom_invalid_python_dep__custom_invalid_python_dep failed to register. Details: Custom pip package installation failed for custom_invalid_python_dep__custom_invalid_python_dep]");
"Workflow custom_invalid_python_dep has failed to register. Failures: [Workflow Node custom_invalid_python_dep__custom_invalid_python_dep failed to register. Details: Custom pip package installation failed for model custom_invalid_python_dep__custom_invalid_python_dep]");
TestUtils.setConfiguration(configManager, "install_py_dep_per_model", "false");
channel.close().sync();
}
Expand Down
13 changes: 10 additions & 3 deletions model-archiver/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,26 @@ For more details refer [default handler documentation](../docs/default_handlers.

### Config file

A model config yaml file. For example:
A model config yaml file. For example:

```
# TS frontend parameters
# See all supported parameters: https://github.com/pytorch/serve/blob/master/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java#L14
# See all supported parameters: https://github.com/pytorch/serve/blob/master/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java#L14
minWorkers: 1 # default: #CPU or #GPU
maxWorkers: 1 # default: #CPU or #GPU
batchSize: 1 # default: 1
maxBatchDelay: 100 # default: 100 msec
responseTimeout: 120 # default: 120 sec
deviceType: cpu # cpu, gpu, neuron
deviceIds: [0,1,2,3] # gpu device ids allocated to this model.
deviceIds: [0,1,2,3] # gpu device ids allocated to this model.
parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel. Default: empty
useVenv: Create python virtual environment when using python backend to install model dependencies
(if enabled globally using install_py_dep_per_model=true) and run workers for model loading
and inference. Note that, although creation of virtual environment adds a latency overhead
(approx. 2 to 3 seconds) during model load and disk space overhead (approx. 25M), overall
it can speed up load time and reduce disk utilization for models with custom dependencies
since it enables reusing custom packages(specified in requirements.txt) and their
supported dependencies that are already available in the base python environment.
# See torchrun parameters: https://pytorch.org/docs/stable/elastic/run.html
torchrun:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
install_py_dep_per_model=true
Loading

0 comments on commit ba8f96a

Please sign in to comment.