Add one example to run batch inference distributed on Ray #2696

Merged: 4 commits, Feb 2, 2024
Changes from 3 commits
69 changes: 69 additions & 0 deletions examples/offline_inference_distributed.py
@@ -0,0 +1,69 @@
"""
This example shows how to use Ray Data to run offline batch inference
distributed across a multi-node cluster.

Learn more about Ray Data at https://docs.ray.io/en/latest/data/data.html
"""

from typing import Dict

import numpy as np
import ray

from vllm import LLM, SamplingParams

# Initialize Ray cluster with required dependencies.
ray.init(
    runtime_env={
        "pip": [
            "accelerate>=0.16.0",
            "transformers>=4.26.0",
            "numpy<1.24",  # remove when mlflow updates beyond 2.2
            "torch",
        ]
    })
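
# Note (illustrative, not part of the original example): to attach to an
# already-running Ray cluster instead of starting a local one, pass its
# address, e.g. ray.init(address="auto", runtime_env={...}).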

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
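# temperature=0.8 adds sampling randomness; top_p=0.95 restricts sampling to
# the smallest set of tokens whose cumulative probability reaches 0.95.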


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }
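
# Illustrative smoke test (assumes a single local GPU; not part of the
# original example), showing the batch format LLMPredictor expects:
#   predictor = LLMPredictor()
#   print(predictor({"text": np.array(["San Francisco is a"])}))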


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats);
# a couple of alternative readers are sketched below.
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
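# For illustration (hypothetical paths), other readers follow the same
# pattern, e.g. ray.data.read_json("s3://<bucket>/prompts.jsonl") or
# ray.data.read_parquet("s3://<bucket>/prompts/").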

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=10,
    # Specify the number of GPUs required per LLM instance.
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=32,
)
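
# Note: with concurrency=10 and num_gpus=1, Ray Data creates 10 LLMPredictor
# replicas, so the cluster needs 10 GPUs in total (one per replica).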

# Write inference output data as Parquet files to S3.
# Multiple files will be written to the output destination; each task writes
# one or more files separately.
ds.write_parquet("s3://<your-output-bucket>")
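
# --- Illustrative local variant (hypothetical paths, not part of this PR) ---
# The same pipeline works without S3 access, e.g. reading a local JSONL file
# with a "text" column and writing Parquet output to a local directory:
#   ds = ray.data.read_json("prompts.jsonl")
#   ds = ds.map_batches(LLMPredictor, concurrency=2, num_gpus=1, batch_size=32)
#   ds.write_parquet("./inference_output")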