Compatibility with vLLM with tensor_parallel_size argument #805

Merged · 10 commits · Jul 23, 2024
33 changes: 33 additions & 0 deletions docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -208,3 +208,36 @@ ray job submit --address http://localhost:8265 --working-dir ray-pipeline -- pyt
1. In this case, we just want two nodes: one to run the Ray head node and one to run a worker.
2. We just want to run one task per node, i.e. the Ray command that starts the head/worker node.
3. We have selected 1 GPU per node, but we could have selected more depending on the pipeline.

## `vLLM` and `tensor_parallel_size`

In order to use the `vLLM` multi-GPU and multi-node capabilities with `ray`, we need to make a few changes to the example pipeline from above. The first change is to specify a value for `tensor_parallel_size`, i.e. the number of GPUs across which the model will be sharded, and the second is to set `ray` as the `distributed_executor_backend`, since the default in `vLLM` is `multiprocessing`:

```python
from distilabel.llms import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="text-generation-ray-pipeline") as pipeline:
    load_data_from_hub = LoadDataFromHub(output_mappings={"prompt": "instruction"})

    text_generation = TextGeneration(
        llm=vLLM(
            model="meta-llama/Meta-Llama-3.1-70B-Instruct",
            tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
            extra_kwargs={
                "tensor_parallel_size": 8,  # shard the model across 8 GPUs
                "distributed_executor_backend": "ray",  # use Ray instead of multiprocessing
            },
        )
    )

    load_data_from_hub >> text_generation
```
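
As in the earlier example, the pipeline is then executed from the script submitted to the Ray cluster. The following is a minimal sketch of that call, assuming the same dataset and runtime parameters as the previous section (the `repo_id` and generation settings shown here are placeholders):

```python
if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            load_data_from_hub.name: {
                # Placeholder dataset: reuse the `repo_id` from the previous example.
                "repo_id": "HuggingFaceH4/instruction-dataset",
                "split": "test",
            },
            text_generation.name: {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 4096,
                    }
                }
            },
        }
    )
```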

Finally, we need to define two environment variables in our `runtime_env.yaml` file:

```yaml
env_vars:
VLLM_USE_RAY_COMPILED_DAG: "1"
VLLM_USE_RAY_SPMD_WORKER: "1"
```

More information about distributed inference with `vLLM` can be found here: [vLLM - Distributed Serving](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -85,7 +85,12 @@ openai = ["openai >= 1.0.0"]
outlines = ["outlines >= 0.0.40"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"]
vllm = [
"vllm >= 0.5.3",
"filelock >= 3.13.4",
# `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
"setuptools",
]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
6 changes: 3 additions & 3 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -516,7 +516,7 @@ async def agenerate( # type: ignore
input: a single input in chat format to generate responses for.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
frequence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
new tokens based on their existing frequency in the text so far, decreasing
the model's likelihood to repeat the same line verbatim. Defaults to `None`.
logit_bias: modify the likelihood of specified tokens appearing in the completion.
@@ -545,8 +545,8 @@ async def agenerate( # type: ignore
only if `tokenizer_id` is `None`. Defaults to `None`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
do_sample: whether to use sampling for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id` is not
`None`. Defaults to `False`.
of the `text_generation` method and will be only used if `tokenizer_id` is not
`None`. Defaults to `False`.
repetition_penalty: the repetition penalty to use for the generation. This argument
is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `None`.
42 changes: 38 additions & 4 deletions src/distilabel/pipeline/local.py
@@ -15,7 +15,8 @@
import multiprocessing as mp
import signal
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
from multiprocessing.pool import Pool
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast

import tblib

@@ -48,6 +49,40 @@ def _init_worker(log_queue: "Queue[Any]") -> None:
    setup_logging(log_queue)


# We create a custom `Pool` class so the created processes are not daemons, allowing
# them to create child processes if necessary (for example when using `vLLM` with `tensor_parallel_size`)
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
class _NoDaemonProcess(mp.Process):
    @property
    def daemon(self) -> bool:
        return False

    @daemon.setter
    def daemon(self, value: bool) -> None:  # type: ignore
        pass


class _NoDaemonContext(type(mp.get_context())):
    Process = _NoDaemonProcess


class _NoDaemonPool(Pool):
    def __init__(
        self,
        processes: Union[int, None] = None,
        initializer: Union[Callable[..., object], None] = None,
        initargs: Iterable[Any] = ...,  # type: ignore
        maxtasksperchild: Union[int, None] = None,
    ) -> None:
        super().__init__(
            processes=processes,
            initializer=initializer,
            initargs=initargs,
            maxtasksperchild=maxtasksperchild,
            context=_NoDaemonContext(),  # type: ignore
        )
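
For context, the following is a minimal standalone sketch (not part of this PR) of the failure mode the custom pool works around: workers of a standard `multiprocessing.Pool` are daemonic, so any attempt to spawn a child process from inside them, as `vLLM` does when `tensor_parallel_size` > 1, raises an error:

```python
import multiprocessing as mp


def spawn_child(_: int) -> str:
    # vLLM with `tensor_parallel_size` > 1 starts worker processes internally;
    # a daemonic pool worker is not allowed to do that.
    child = mp.Process(target=print, args=("hello from the child process",))
    child.start()  # AssertionError: daemonic processes are not allowed to have children
    child.join()
    return "ok"


if __name__ == "__main__":
    with mp.Pool(processes=1) as pool:
        try:
            pool.apply(spawn_child, (0,))
        except AssertionError as exc:
            print(f"standard Pool failed: {exc}")
```

Swapping in `_NoDaemonPool` makes the workers non-daemonic, so the nested `Process.start()` call succeeds.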


class Pipeline(BasePipeline):
"""Local pipeline implementation using `multiprocessing`."""

@@ -133,10 +168,9 @@ def run(
            return distiset

        num_processes = self.dag.get_total_replica_count()
        ctx = mp.get_context()  # type: ignore
        with (
            ctx.Manager() as manager,
            ctx.Pool(
            mp.Manager() as manager,
            _NoDaemonPool(
                num_processes,
                initializer=_init_worker,
                initargs=(self._log_queue,),