From 54bc4a087cd9c57373333d1769db7b545b7d916c Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Fri, 12 Jul 2024 17:35:01 +0000 Subject: [PATCH 1/6] Introduce pipeline parallel degree config --- engines/python/setup/djl_python/arg_parser.py | 5 ++++ .../properties_manager/hf_properties.py | 8 +++-- .../lmi_dist_rb_properties.py | 1 + .../rolling_batch/lmi_dist_rolling_batch.py | 3 ++ engines/python/setup/djl_python_engine.py | 3 ++ .../java/ai/djl/python/engine/Connection.java | 29 +++++++++++++----- .../main/java/ai/djl/python/engine/PyEnv.java | 30 ++++++++++++++++++- .../java/ai/djl/python/engine/PyModel.java | 7 +++++ .../java/ai/djl/python/engine/PyProcess.java | 9 +++--- .../docker/partition/properties_manager.py | 10 +++++++ serving/docker/partition/trt_llm_partition.py | 5 ++++ .../djl/serving/wlm/LmiConfigRecommender.java | 9 ++++++ .../java/ai/djl/serving/wlm/LmiUtils.java | 14 +++++++-- 13 files changed, 114 insertions(+), 19 deletions(-) diff --git a/engines/python/setup/djl_python/arg_parser.py b/engines/python/setup/djl_python/arg_parser.py index c0cd45f60..f2fa41889 100644 --- a/engines/python/setup/djl_python/arg_parser.py +++ b/engines/python/setup/djl_python/arg_parser.py @@ -71,6 +71,11 @@ def python_engine_args(): dest="tensor_parallel_degree", type=int, help='The tensor parallel degree') + parser.add_argument('--pipeline-parallel-degree', + required=False, + dest="pipeline_parallel_degree", + type=int, + help='The pipeline parallel degree') parser.add_argument('--cluster-size', required=False, dest="cluster_size", diff --git a/engines/python/setup/djl_python/properties_manager/hf_properties.py b/engines/python/setup/djl_python/properties_manager/hf_properties.py index c10e2d51d..4a61bc90a 100644 --- a/engines/python/setup/djl_python/properties_manager/hf_properties.py +++ b/engines/python/setup/djl_python/properties_manager/hf_properties.py @@ -50,6 +50,7 @@ class HuggingFaceProperties(Properties): device_id: int = -1 task: str = None tensor_parallel_degree: int = -1 + pipeline_parallel_degree: int = -1 cluster_size: int = 1 device_map: str = None load_in_4bit: Optional[bool] = None @@ -120,14 +121,15 @@ def construct_kwargs_device_map(self): self.device = None logging.info(f"Using device map {self.device_map}") elif self.tensor_parallel_degree > 0 \ + and self.pipeline_parallel_degree > 0 \ and self.cluster_size > 0 \ and torch.cuda.device_count() > 0: self.kwargs["device_map"] = "auto" self.device = None world_size = torch.cuda.device_count() * self.cluster_size - assert world_size == self.tensor_parallel_degree, \ - f"TP degree ({self.tensor_parallel_degree}) doesn't match available GPUs ({world_size})" - logging.info(f"Using {world_size} gpus") + assert world_size == self.tensor_parallel_degree*self.pipeline_parallel_degree, \ + f"TP*PP degree ({self.tensor_parallel_degree*self.pipeline_parallel_degree}) doesn't match available GPUs ({world_size})" + logging.info(f"Using {world_size} gpus collectively.") return self @model_validator(mode='after') diff --git a/engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py b/engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py index 6c9287105..c12ce1e83 100644 --- a/engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py +++ b/engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py @@ -35,6 +35,7 @@ class LmiDistRbProperties(Properties): load_format: Optional[str] = "auto" quantize: Optional[LmiDistQuantizeMethods] = None 
tensor_parallel_degree: Optional[int] = None + pipeline_parallel_degree: Optional[int] = None max_rolling_batch_prefill_tokens: Optional[int] = None # Adjustable prefix model length for certain 32k or longer model max_model_len: Optional[int] = None diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 644c550ff..0763f36b3 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -62,6 +62,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs): engine_args = VllmEngineArgs( model=self.lmi_dist_config.model_id_or_path, tensor_parallel_size=self.lmi_dist_config.tensor_parallel_degree, + pipeline_parallel_size=self.lmi_dist_config.pipeline_parallel_degree, dtype=DTYPE_MAPPER[self.lmi_dist_config.dtype], seed=0, max_model_len=self.lmi_dist_config.max_model_len, @@ -81,6 +82,8 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs): **engine_kwargs) kwargs = {} + print(f"engine_args: {engine_args}, kwargs={kwargs}") + if self.lmi_dist_config.max_rolling_batch_prefill_tokens is None: kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS self.engine = engine_from_args(engine_args, **kwargs) diff --git a/engines/python/setup/djl_python_engine.py b/engines/python/setup/djl_python_engine.py index 34e3b8cf2..18261a707 100644 --- a/engines/python/setup/djl_python_engine.py +++ b/engines/python/setup/djl_python_engine.py @@ -52,6 +52,7 @@ def __init__(self, args, service): self.service = service self.device_id = args.device_id self.tensor_parallel_degree = args.tensor_parallel_degree + self.pipeline_parallel_degree = args.pipeline_parallel_degree self.cluster_size = args.cluster_size self.entry_point = args.entry_point self.recommended_entry_point = args.recommended_entry_point @@ -123,6 +124,8 @@ def run_server(self): prop = inputs.get_properties() if self.tensor_parallel_degree: prop["tensor_parallel_degree"] = self.tensor_parallel_degree + if self.pipeline_parallel_degree: + prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree if self.cluster_size: prop["cluster_size"] = self.cluster_size prop["device_id"] = self.device_id diff --git a/engines/python/src/main/java/ai/djl/python/engine/Connection.java b/engines/python/src/main/java/ai/djl/python/engine/Connection.java index 396ae4b84..835ba8ea8 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/Connection.java +++ b/engines/python/src/main/java/ai/djl/python/engine/Connection.java @@ -126,12 +126,23 @@ static String[] getPythonStartCmd( int deviceId = device.getDeviceId(); int clusterSize = PyEnv.getClusterSize(); int tensorParallelDegree = pyEnv.getTensorParallelDegree(); + int pipelineParallelDegree = pyEnv.getPipelineParallelDegree(); String entryPoint = pyEnv.getEntryPoint(); String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint(); if (PyEnv.isMultiNode()) { - String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree / clusterSize); + + int mpiSlots = (tensorParallelDegree*pipelineParallelDegree) / 2; + int mpiProcesses = tensorParallelDegree*pipelineParallelDegree; + + // if (pipelineParallelDegree != clusterSize) { + // logger.warn("In multi-node setting, pipeline parallel degree must equal the cluster size. 
Setting pp degree = cluster size = {}.", clusterSize);
+            //     pipelineParallelDegree = clusterSize;
+            // }
+
+            String cudaDevices = getVisibleDevices(workerId, mpiSlots);
             logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
+            logger.info("Received: pp degree: {} and tp degree: {} and cluster size: {}", pipelineParallelDegree, tensorParallelDegree, clusterSize);
             StringBuilder sb = new StringBuilder();
             boolean first = true;
             for (String host : hosts) {
@@ -140,12 +151,12 @@ static String[] getPythonStartCmd(
             } else {
                 sb.append(',');
             }
-            sb.append(host).append(':').append(tensorParallelDegree / clusterSize);
+            sb.append(host).append(':').append(mpiSlots);
         }
-        String[] args = new String[46];
+        String[] args = new String[48];
         args[0] = "mpirun";
         args[1] = "-np";
-        args[2] = String.valueOf(tensorParallelDegree);
+        args[2] = String.valueOf(mpiProcesses);
         args[3] = "--host";
         args[4] = sb.toString();
         args[5] = "--allow-run-as-root";
@@ -185,10 +196,12 @@ static String[] getPythonStartCmd(
         args[39] = String.valueOf(port);
         args[40] = "--tensor-parallel-degree";
         args[41] = String.valueOf(tensorParallelDegree);
-        args[42] = "--cluster-size";
-        args[43] = String.valueOf(clusterSize);
-        args[44] = "--recommended-entry-point";
-        args[45] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
+        args[42] = "--pipeline-parallel-degree";
+        args[43] = String.valueOf(pipelineParallelDegree);
+        args[44] = "--cluster-size";
+        args[45] = String.valueOf(clusterSize);
+        args[46] = "--recommended-entry-point";
+        args[47] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
         return args;
     }
 
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
index 8ecb53628..eeb7b7666 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -55,6 +55,7 @@ public class PyEnv {
     private int predictTimeout;
     private int modelLoadingTimeout;
    private int tensorParallelDegree;
+    private int pipelineParallelDegree;
     private Map envs;
     private Map initParameters;
     private boolean initialized;
@@ -363,6 +364,33 @@ public void setTensorParallelDegree(int tensorParallelDegree) {
         this.tensorParallelDegree = tensorParallelDegree;
     }
 
+    /**
+     * Returns the pipeline parallel degree.
+     *
+     * @return the pipeline parallel degree
+     */
+    public int getPipelineParallelDegree() {
+        if (pipelineParallelDegree == 0) {
+            String value = Utils.getenv("PIPELINE_PARALLEL_DEGREE");
+            if (value != null) {
+                pipelineParallelDegree = Integer.parseInt(value);
+            } else {
+                pipelineParallelDegree = 1;
+            }
+        }
+
+        return pipelineParallelDegree;
+    }
+
+    /**
+     * Sets the pipeline parallel degree.
+ * + * @param pipelineParallelDegree the pipeline parallel degree + */ + public void setPipelineParallelDegree(int pipelineParallelDegree) { + this.pipelineParallelDegree = pipelineParallelDegree; + } + int getMpiWorkers() { int gpuCount = CudaUtils.getGpuCount() * clusterSize; String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES"); @@ -373,7 +401,7 @@ int getMpiWorkers() { } gpuCount = visibleCount; } - return gpuCount / getTensorParallelDegree(); + return gpuCount / (getTensorParallelDegree() * getPipelineParallelDegree()); } /** diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java index ee228b466..f6f43c96a 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java @@ -130,6 +130,13 @@ public void load(Path modelPath, String prefix, Map options) throws I pyEnv.setTensorParallelDegree(Integer.parseInt(value)); } break; + case "pipeline_parallel_degree": + if(value != null) { + pyEnv.setPipelineParallelDegree(Integer.parseInt(value)); + } else { + pyEnv.setPipelineParallelDegree(1); + } + break; case "handler": pyEnv.setHandler(value); break; diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index 046199253..a38ea2967 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -63,21 +63,22 @@ class PyProcess { int port = counter.getAndIncrement(); if (pyEnv.isMpiMode()) { int tensorParallelDegree = pyEnv.getTensorParallelDegree(); + int pipelineParallelDegree = pyEnv.getPipelineParallelDegree(); int clusterSize = PyEnv.getClusterSize(); - connections = new ArrayList<>(tensorParallelDegree); + connections = new ArrayList<>(tensorParallelDegree*pipelineParallelDegree); if (clusterSize > 1) { hosts = getHosts(clusterSize); - for (int i = 0; i < tensorParallelDegree; ++i) { + for (int i = 0; i < tensorParallelDegree*pipelineParallelDegree; ++i) { connections.add( new Connection( pyEnv, port, i, - hosts[i / (tensorParallelDegree / clusterSize)])); + hosts[i / (tensorParallelDegree*pipelineParallelDegree / clusterSize)])); } } else { - for (int i = 0; i < tensorParallelDegree; ++i) { + for (int i = 0; i < tensorParallelDegree*pipelineParallelDegree; ++i) { connections.add(new Connection(pyEnv, port, i, "127.0.0.1")); } } diff --git a/serving/docker/partition/properties_manager.py b/serving/docker/partition/properties_manager.py index 50b3af62b..06ed4fe75 100644 --- a/serving/docker/partition/properties_manager.py +++ b/serving/docker/partition/properties_manager.py @@ -45,6 +45,9 @@ def __init__(self, args, **kwargs): if args.tensor_parallel_degree: self.properties[ 'option.tensor_parallel_degree'] = args.tensor_parallel_degree + if args.pipeline_parallel_degree: + self.properties[ + 'option.pipeline_parallel_degree'] = args.pipeline_parallel_degree if args.quantize: self.properties['option.quantize'] = args.quantize @@ -57,6 +60,7 @@ def __init__(self, args, **kwargs): if self.is_mpi_mode: self.validate_tp_degree() + self.validate_pp_degree() self.set_and_validate_entry_point() self.set_and_validate_save_mp_checkpoint_path() @@ -144,6 +148,12 @@ def validate_tp_degree(self): f'GPU devices are not enough to run {tensor_parallel_degree} partitions.' 
                )
 
+    def validate_pp_degree(self):
+        pipeline_parallel_degree = self.properties.get(
+            'option.pipeline_parallel_degree')
+        if not pipeline_parallel_degree:
+            raise ValueError('Expecting pipeline_parallel_degree to be set (default is 1)')
+
     def set_and_validate_entry_point(self):
         entry_point = self.properties.get('option.entryPoint')
         quantize = self.properties.get('option.quantize')
diff --git a/serving/docker/partition/trt_llm_partition.py b/serving/docker/partition/trt_llm_partition.py
index 9284fa590..151c36e5b 100644
--- a/serving/docker/partition/trt_llm_partition.py
+++ b/serving/docker/partition/trt_llm_partition.py
@@ -24,6 +24,7 @@ def create_trt_llm_repo(properties, args):
     kwargs = remove_option_from_properties(properties)
     kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo
     kwargs["tensor_parallel_degree"] = args.tensor_parallel_degree
+    kwargs["pipeline_parallel_degree"] = args.pipeline_parallel_degree
     model_id_or_path = args.model_path or kwargs['model_id']
     create_model_repo(model_id_or_path, **kwargs)
 
@@ -48,6 +49,10 @@ def main():
                         type=int,
                         required=True,
                         help='Tensor parallel degree')
+    parser.add_argument('--pipeline_parallel_degree',
+                        type=int,
+                        required=True,
+                        help='Pipeline parallel degree')
     parser.add_argument('--model_path',
                         type=str,
                         required=False,
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
index 9bc25c11f..b2afea0c9 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -84,6 +84,7 @@ static void configure(
         setRollingBatch(lmiProperties, modelConfig, features);
         setMpiMode(lmiProperties, modelConfig, features);
         setTensorParallelDegree(lmiProperties);
+        setPipelineParallelDegree(lmiProperties);
         setRollingBatchSize(lmiProperties);
     }
 
@@ -149,6 +150,14 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
         lmiProperties.setProperty("option.tensor_parallel_degree", tpDegree);
     }
 
+    private static void setPipelineParallelDegree(Properties lmiProperties) {
+        if (lmiProperties.containsKey("option.pipeline_parallel_degree")) {
+            return;
+        }
+        String ppDegree = Utils.getenv("PIPELINE_PARALLEL_DEGREE", "1");
+        lmiProperties.setProperty("option.pipeline_parallel_degree", ppDegree);
+    }
+
     private static void setDynamicBatch(
             Properties lmiProperties,
             LmiUtils.HuggingFaceModelConfig modelConfig,
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
index 259a0bd53..97440f327 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -110,6 +110,7 @@ static void convertTrtLLM(ModelInfo info) throws IOException {
         if (modelId == null) {
             modelId = trtRepo.toString();
         }
+
         String tpDegree = info.prop.getProperty("option.tensor_parallel_degree");
         if (tpDegree == null) {
             tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
@@ -118,6 +119,11 @@ static void convertTrtLLM(ModelInfo info) throws IOException {
             tpDegree = String.valueOf(CudaUtils.getGpuCount());
         }
 
+        String ppDegree = info.prop.getProperty("option.pipeline_parallel_degree");
+        if (ppDegree == null) {
+            ppDegree = Utils.getenv("PIPELINE_PARALLEL_DEGREE", "1");
+        }
+
         // TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
         if (info.prop.containsKey("trtllm_python_backend")) {
             // Inflight batching support is not available for certain models like t5.
@@ -125,12 +131,12 @@ static void convertTrtLLM(ModelInfo info) throws IOException { // And whether it is valid or not is checked in tensorrt_llm_toolkit. So it is not // necessary to check here. if (!isValidTrtLlmPythonModelRepo(trtRepo)) { - info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree); + info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree, ppDegree); } } else { info.prop.put("option.rolling_batch", "trtllm"); if (!isValidTrtLlmModelRepo(trtRepo)) { - info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree); + info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree, ppDegree); } } } @@ -308,7 +314,7 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo } } - private static Path buildTrtLlmArtifacts(Path modelDir, String modelId, String tpDegree) + private static Path buildTrtLlmArtifacts(Path modelDir, String modelId, String tpDegree, String ppDegree) throws IOException { logger.info("Converting model to TensorRT-LLM artifacts"); String hash = Utils.hash(modelId + tpDegree); @@ -329,6 +335,8 @@ private static Path buildTrtLlmArtifacts(Path modelDir, String modelId, String t trtLlmRepoDir.toString(), "--tensor_parallel_degree", tpDegree, + "--pipeline_parallel_degree", + ppDegree, "--model_path", modelId }; From f3fbb33ed80859a655df4d96f545fa0fc577c175 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Mon, 15 Jul 2024 17:45:20 +0000 Subject: [PATCH 2/6] Address comment --- .../rolling_batch/lmi_dist_rolling_batch.py | 2 +- .../java/ai/djl/python/engine/Connection.java | 15 +++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 0763f36b3..51dd1f006 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -82,7 +82,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs): **engine_kwargs) kwargs = {} - print(f"engine_args: {engine_args}, kwargs={kwargs}") + logging.info(f"engine_args: {engine_args}, kwargs: {kwargs}") if self.lmi_dist_config.max_rolling_batch_prefill_tokens is None: kwargs["warmup_prefill_tokens"] = _WARMUP_PREFILL_TOKENS diff --git a/engines/python/src/main/java/ai/djl/python/engine/Connection.java b/engines/python/src/main/java/ai/djl/python/engine/Connection.java index 835ba8ea8..46fad2ef0 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/Connection.java +++ b/engines/python/src/main/java/ai/djl/python/engine/Connection.java @@ -132,15 +132,10 @@ static String[] getPythonStartCmd( if (PyEnv.isMultiNode()) { - int mpiSlots = (tensorParallelDegree*pipelineParallelDegree) / 2; - int mpiProcesses = tensorParallelDegree*pipelineParallelDegree; + int localSize = (tensorParallelDegree*pipelineParallelDegree) / clusterSize; + int worldSize = tensorParallelDegree*pipelineParallelDegree; - // if (pipelineParallelDegree != clusterSize) { - // logger.warn("In multi-node setting, pipeline parallel degree must equal the cluster size. 
Setting pp degree = cluster size = {}.", clusterSize);
-            //     pipelineParallelDegree = clusterSize;
-            // }
-
-            String cudaDevices = getVisibleDevices(workerId, mpiSlots);
+            String cudaDevices = getVisibleDevices(workerId, localSize);
             logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
             logger.info("Received: pp degree: {} and tp degree: {} and cluster size: {}", pipelineParallelDegree, tensorParallelDegree, clusterSize);
             StringBuilder sb = new StringBuilder();
             boolean first = true;
             for (String host : hosts) {
@@ -151,12 +146,12 @@ static String[] getPythonStartCmd(
             } else {
                 sb.append(',');
             }
-            sb.append(host).append(':').append(mpiSlots);
+            sb.append(host).append(':').append(localSize);
         }
         String[] args = new String[48];
         args[0] = "mpirun";
         args[1] = "-np";
-        args[2] = String.valueOf(mpiProcesses);
+        args[2] = String.valueOf(worldSize);
         args[3] = "--host";
         args[4] = sb.toString();
         args[5] = "--allow-run-as-root";

From e202bb1c2de8a42076e8bd48701d1dd3a7072845 Mon Sep 17 00:00:00 2001
From: Nikhil Kulkarni 
Date: Wed, 17 Jul 2024 19:12:40 +0000
Subject: [PATCH 3/6] Format Python

---
 .../setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py | 3 ++-
 engines/python/setup/djl_python_engine.py                    | 3 ++-
 serving/docker/partition/properties_manager.py               | 4 +++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 51dd1f006..d6490c530 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -62,7 +62,8 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
         engine_args = VllmEngineArgs(
             model=self.lmi_dist_config.model_id_or_path,
             tensor_parallel_size=self.lmi_dist_config.tensor_parallel_degree,
-            pipeline_parallel_size=self.lmi_dist_config.pipeline_parallel_degree,
+            pipeline_parallel_size=self.lmi_dist_config.
+            pipeline_parallel_degree,
             dtype=DTYPE_MAPPER[self.lmi_dist_config.dtype],
             seed=0,
             max_model_len=self.lmi_dist_config.max_model_len,
diff --git a/engines/python/setup/djl_python_engine.py b/engines/python/setup/djl_python_engine.py
index 18261a707..4b3e0377d 100644
--- a/engines/python/setup/djl_python_engine.py
+++ b/engines/python/setup/djl_python_engine.py
@@ -125,7 +125,8 @@ def run_server(self):
             if self.tensor_parallel_degree:
                 prop["tensor_parallel_degree"] = self.tensor_parallel_degree
             if self.pipeline_parallel_degree:
-                prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree
+                prop[
+                    "pipeline_parallel_degree"] = self.pipeline_parallel_degree
             if self.cluster_size:
                 prop["cluster_size"] = self.cluster_size
             prop["device_id"] = self.device_id
diff --git a/serving/docker/partition/properties_manager.py b/serving/docker/partition/properties_manager.py
index 06ed4fe75..70d7128e9 100644
--- a/serving/docker/partition/properties_manager.py
+++ b/serving/docker/partition/properties_manager.py
@@ -152,7 +152,9 @@ def validate_pp_degree(self):
         pipeline_parallel_degree = self.properties.get(
             'option.pipeline_parallel_degree')
         if not pipeline_parallel_degree:
-            raise ValueError('Expecting pipeline_parallel_degree to be set (default is 1)')
+            raise ValueError(
+                'Expecting pipeline_parallel_degree to be set (default is 1)'
+            )
 
     def set_and_validate_entry_point(self):
         entry_point = self.properties.get('option.entryPoint')

From ab0f4c305a4acad6a2b7d8ec872d3a00718a15b6 Mon Sep 17 00:00:00 2001
From: Nikhil Kulkarni 
Date: Wed, 17 Jul 2024 19:23:03 +0000
Subject: [PATCH 4/6] Format Java, Python and add divisibility criteria

---
 .../java/ai/djl/python/engine/Connection.java | 19 ++++++++++++++++---
 .../java/ai/djl/python/engine/PyModel.java    |  2 +-
 .../java/ai/djl/python/engine/PyProcess.java  | 12 ++++++++----
 .../java/ai/djl/serving/wlm/LmiUtils.java     |  4 ++--
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/engines/python/src/main/java/ai/djl/python/engine/Connection.java b/engines/python/src/main/java/ai/djl/python/engine/Connection.java
index 46fad2ef0..cd6494f1f 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/Connection.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/Connection.java
@@ -132,12 +132,25 @@ static String[] getPythonStartCmd(
 
         if (PyEnv.isMultiNode()) {
 
-            int localSize = (tensorParallelDegree*pipelineParallelDegree) / clusterSize;
-            int worldSize = tensorParallelDegree*pipelineParallelDegree;
+            int worldSize = tensorParallelDegree * pipelineParallelDegree;
+
+            if (worldSize % clusterSize != 0) {
+                throw new IllegalArgumentException(
+                        "Error: Cannot use cluster size: "
+                                + clusterSize
+                                + " for world size (number of total GPUs): "
+                                + worldSize);
+            }
+
+            int localSize = worldSize / clusterSize;
 
             String cudaDevices = getVisibleDevices(workerId, localSize);
             logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
-            logger.info("Received: pp degree: {} and tp degree: {} and cluster size: {}", pipelineParallelDegree, tensorParallelDegree, clusterSize);
+            logger.info(
+                    "Received: pp degree: {} and tp degree: {} and cluster size: {}",
+                    pipelineParallelDegree,
+                    tensorParallelDegree,
+                    clusterSize);
             StringBuilder sb = new StringBuilder();
             boolean first = true;
             for (String host : hosts) {
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
index 
f6f43c96a..2c3f368c9 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java @@ -131,7 +131,7 @@ public void load(Path modelPath, String prefix, Map options) throws I } break; case "pipeline_parallel_degree": - if(value != null) { + if (value != null) { pyEnv.setPipelineParallelDegree(Integer.parseInt(value)); } else { pyEnv.setPipelineParallelDegree(1); diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index a38ea2967..0d91b675e 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -65,20 +65,24 @@ class PyProcess { int tensorParallelDegree = pyEnv.getTensorParallelDegree(); int pipelineParallelDegree = pyEnv.getPipelineParallelDegree(); int clusterSize = PyEnv.getClusterSize(); - connections = new ArrayList<>(tensorParallelDegree*pipelineParallelDegree); + connections = new ArrayList<>(tensorParallelDegree * pipelineParallelDegree); if (clusterSize > 1) { hosts = getHosts(clusterSize); - for (int i = 0; i < tensorParallelDegree*pipelineParallelDegree; ++i) { + for (int i = 0; i < tensorParallelDegree * pipelineParallelDegree; ++i) { connections.add( new Connection( pyEnv, port, i, - hosts[i / (tensorParallelDegree*pipelineParallelDegree / clusterSize)])); + hosts[ + i + / (tensorParallelDegree + * pipelineParallelDegree + / clusterSize)])); } } else { - for (int i = 0; i < tensorParallelDegree*pipelineParallelDegree; ++i) { + for (int i = 0; i < tensorParallelDegree * pipelineParallelDegree; ++i) { connections.add(new Connection(pyEnv, port, i, "127.0.0.1")); } } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 97440f327..92971e765 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -314,8 +314,8 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo } } - private static Path buildTrtLlmArtifacts(Path modelDir, String modelId, String tpDegree, String ppDegree) - throws IOException { + private static Path buildTrtLlmArtifacts( + Path modelDir, String modelId, String tpDegree, String ppDegree) throws IOException { logger.info("Converting model to TensorRT-LLM artifacts"); String hash = Utils.hash(modelId + tpDegree); String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null); From 3e91a41f630ddcfa9b3555e1ad115ddfb38256b6 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 17 Jul 2024 21:07:14 +0000 Subject: [PATCH 5/6] Remove assert --- .../setup/djl_python/properties_manager/hf_properties.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/engines/python/setup/djl_python/properties_manager/hf_properties.py b/engines/python/setup/djl_python/properties_manager/hf_properties.py index 4a61bc90a..b96f4be9a 100644 --- a/engines/python/setup/djl_python/properties_manager/hf_properties.py +++ b/engines/python/setup/djl_python/properties_manager/hf_properties.py @@ -127,8 +127,10 @@ def construct_kwargs_device_map(self): self.kwargs["device_map"] = "auto" self.device = None world_size = torch.cuda.device_count() * self.cluster_size - assert world_size == self.tensor_parallel_degree*self.pipeline_parallel_degree, \ - f"TP*PP degree ({self.tensor_parallel_degree*self.pipeline_parallel_degree}) doesn't match available GPUs 
({world_size})" + + if world_size != self.tensor_parallel_degree*self.pipeline_parallel_degree: + logging.error(f"TP*PP degree ({self.tensor_parallel_degree*self.pipeline_parallel_degree}) doesn't match available GPUs ({world_size})") + logging.info(f"Using {world_size} gpus collectively.") return self From a8455c8571e849d88844ae018b05e3937b271689 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 17 Jul 2024 21:11:07 +0000 Subject: [PATCH 6/6] Remove assert --- .../setup/djl_python/properties_manager/hf_properties.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/engines/python/setup/djl_python/properties_manager/hf_properties.py b/engines/python/setup/djl_python/properties_manager/hf_properties.py index b96f4be9a..a1d1bee4d 100644 --- a/engines/python/setup/djl_python/properties_manager/hf_properties.py +++ b/engines/python/setup/djl_python/properties_manager/hf_properties.py @@ -128,8 +128,10 @@ def construct_kwargs_device_map(self): self.device = None world_size = torch.cuda.device_count() * self.cluster_size - if world_size != self.tensor_parallel_degree*self.pipeline_parallel_degree: - logging.error(f"TP*PP degree ({self.tensor_parallel_degree*self.pipeline_parallel_degree}) doesn't match available GPUs ({world_size})") + if world_size != self.tensor_parallel_degree * self.pipeline_parallel_degree: + raise ValueError( + f"TP*PP degree ({self.tensor_parallel_degree*self.pipeline_parallel_degree}) doesn't match available GPUs ({world_size})" + ) logging.info(f"Using {world_size} gpus collectively.") return self
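
Taken together, the series exposes one new knob, option.pipeline_parallel_degree (environment variable PIPELINE_PARALLEL_DEGREE, default 1), alongside the existing option.tensor_parallel_degree and the cluster size. In multi-node MPI mode, Connection.getPythonStartCmd now computes the world size as TP degree times PP degree, requires it to be divisible by the cluster size, and gives each host world_size / cluster_size slots. The Python sketch below mirrors that arithmetic for illustration only; it is not code from the patch set, and the function name is hypothetical:

    # Sketch of the MPI process-layout arithmetic in Connection.getPythonStartCmd.
    def mpi_layout(tensor_parallel_degree, pipeline_parallel_degree,
                   cluster_size, hosts):
        # One MPI rank per GPU across the whole cluster.
        world_size = tensor_parallel_degree * pipeline_parallel_degree
        # Divisibility criterion from PATCH 4/6: every node hosts the same
        # number of ranks, so cluster size must divide the world size.
        if world_size % cluster_size != 0:
            raise ValueError(f"Cluster size {cluster_size} must divide "
                             f"world size {world_size}")
        local_size = world_size // cluster_size  # ranks (GPUs) per node
        host_arg = ",".join(f"{host}:{local_size}" for host in hosts)
        return f"mpirun -np {world_size} --host {host_arg}"

    # Example: TP=4, PP=2 on a 2-node cluster -> 8 ranks total, 4 per node.
    print(mpi_layout(4, 2, 2, ["node-1", "node-2"]))
    # mpirun -np 8 --host node-1:4,node-2:4

For example, an 8-GPU deployment across two 4-GPU nodes could set option.tensor_parallel_degree=4 and option.pipeline_parallel_degree=2 in serving.properties. Deployments that never set the new option keep their current behavior, since every code path defaults the pipeline parallel degree to 1.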