Add vLLM TPU example RayService manifest (#3000)
* Add vLLM TPU example RayService

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

* Add comments

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

* Lint

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
ryanaoleary authored Feb 12, 2025
1 parent 9559227 commit b227924
Showing 2 changed files with 255 additions and 0 deletions.
@@ -0,0 +1,146 @@
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: vllm-tpu
spec:
  serveConfigV2: |
    applications:
      - name: llm
        import_path: ray-operator.config.samples.vllm.serve_tpu:model
        deployments:
          - name: VLLMDeployment
            num_replicas: 1
        runtime_env:
          working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
          env_vars:
            MODEL_ID: "$MODEL_ID"
            MAX_MODEL_LEN: "$MAX_MODEL_LEN"
            DTYPE: "$DTYPE"
            TOKENIZER_MODE: "$TOKENIZER_MODE"
            TPU_CHIPS: "8"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
            gke-gcsfuse/cpu-limit: "0"
            gke-gcsfuse/memory-limit: "0"
            gke-gcsfuse/ephemeral-storage-limit: "0"
        spec:
          # replace $KSA_NAME with your Kubernetes Service Account
          serviceAccountName: $KSA_NAME
          containers:
            - name: ray-head
              # replace $VLLM_IMAGE with your vLLM container image
              image: $VLLM_IMAGE
              imagePullPolicy: IfNotPresent
              ports:
                - containerPort: 6379
                  name: gcs
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                - containerPort: 8000
                  name: serve
              env:
                - name: HUGGING_FACE_HUB_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
                - name: VLLM_XLA_CACHE_PATH
                  value: "/data"
              resources:
                limits:
                  cpu: "2"
                  memory: 8G
                requests:
                  cpu: "2"
                  memory: 8G
              volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
          volumes:
            - name: gke-gcsfuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  # replace $GSBUCKET with your GCS bucket
                  bucketName: $GSBUCKET
                  mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    workerGroupSpecs:
      - groupName: tpu-group
        replicas: 1
        minReplicas: 1
        maxReplicas: 1
        numOfHosts: 1
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            # replace $KSA_NAME with your Kubernetes Service Account
            serviceAccountName: $KSA_NAME
            containers:
              - name: ray-worker
                # replace $VLLM_IMAGE with your vLLM container image
                image: $VLLM_IMAGE
                imagePullPolicy: IfNotPresent
                resources:
                  limits:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                  requests:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                env:
                  - name: JAX_PLATFORMS
                    value: "tpu"
                  - name: HUGGING_FACE_HUB_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-secret
                        key: hf_api_token
                  - name: VLLM_XLA_CACHE_PATH
                    value: "/data"
                volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
            volumes:
              - name: gke-gcsfuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    # replace $GSBUCKET with your GCS bucket
                    bucketName: $GSBUCKET
                    mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
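
The manifest leaves MODEL_ID, MAX_MODEL_LEN, DTYPE, TOKENIZER_MODE, KSA_NAME, VLLM_IMAGE, and GSBUCKET as shell-style $ placeholders. A minimal sketch of one way to render them before kubectl apply (not part of the sample itself; the local filenames are hypothetical, and envsubst or manual editing work equally well):

import os
import string

# Load the manifest template from a hypothetical local copy.
with open("ray-service.vllm-tpu.yaml") as f:
    template = string.Template(f.read())

# safe_substitute fills $NAME placeholders from the environment and leaves
# any unset placeholder intact rather than raising KeyError.
rendered = template.safe_substitute(os.environ)

with open("ray-service.rendered.yaml", "w") as f:
    f.write(rendered)

# Afterwards: kubectl apply -f ray-service.rendered.yaml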
ray-operator/config/samples/vllm/serve_tpu.py (109 additions, 0 deletions)
@@ -0,0 +1,109 @@
import os

import json
import logging
from typing import Dict, List, Optional

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from starlette.responses import Response

from vllm import LLM, SamplingParams

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        model_id,
        num_tpu_chips,
        max_model_len,
        tokenizer_mode,
        dtype,
    ):
        self.llm = LLM(
            model=model_id,
            tensor_parallel_size=num_tpu_chips,
            max_model_len=max_model_len,
            dtype=dtype,
            download_dir=os.environ['VLLM_XLA_CACHE_PATH'],  # Error if not provided.
            tokenizer_mode=tokenizer_mode,
            enforce_eager=True,
        )

    @app.post("/v1/generate")
    async def generate(self, request: Request):
        request_dict = await request.json()
        prompts = request_dict.pop("prompt")
        max_toks = int(request_dict.pop("max_tokens"))
        print("Processing prompt ", prompts)
        sampling_params = SamplingParams(temperature=0.7,
                                         top_p=1.0,
                                         n=1,
                                         max_tokens=max_toks)

        outputs = self.llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = ""
            token_ids = []
            for completion_output in output.outputs:
                generated_text += completion_output.text
                token_ids.extend(list(completion_output.token_ids))

            print("Generated text: ", generated_text)
            ret = {
                "prompt": prompt,
                "text": generated_text,
                "token_ids": token_ids,
            }

        return Response(content=json.dumps(ret))
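
# For reference, a minimal client-side sketch of calling the endpoint above once
# the app is deployed, assuming the Serve HTTP proxy is reachable at
# localhost:8000 (hypothetical, e.g. via kubectl port-forward):
#
#   import json
#   import urllib.request
#
#   payload = json.dumps({"prompt": "What is a TPU?", "max_tokens": 128}).encode()
#   req = urllib.request.Request("http://localhost:8000/v1/generate", data=payload,
#                                headers={"Content-Type": "application/json"})
#   with urllib.request.urlopen(req) as resp:
#       print(json.loads(resp.read())["text"])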

def get_num_tpu_chips() -> int:
    if "TPU" not in ray.cluster_resources():
        # Pass in TPU chips when the current Ray cluster resources can't be
        # auto-detected (i.e., for autoscaling).
        if os.environ.get('TPU_CHIPS') is not None:
            return int(os.environ.get('TPU_CHIPS'))
        return 0
    return int(ray.cluster_resources()["TPU"])


def get_max_model_len() -> Optional[int]:
    if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
        return None
    return int(os.environ['MAX_MODEL_LEN'])


def get_tokenizer_mode() -> str:
    if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
        return "auto"
    return os.environ['TOKENIZER_MODE']


def get_dtype() -> str:
    if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
        return "auto"
    return os.environ['DTYPE']


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    ray.init(ignore_reinit_error=True, address="ray://localhost:10001")

    model_id = os.environ['MODEL_ID']

    num_tpu_chips = get_num_tpu_chips()
    pg_resources = []
    pg_resources.append({"CPU": 1})  # for the deployment replica
    for _ in range(num_tpu_chips):
        pg_resources.append({"CPU": 1, "TPU": 1})  # for the vLLM actors

    # Use PACK strategy since the deployment may use more than one TPU node.
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="PACK").bind(model_id, num_tpu_chips,
                                              get_max_model_len(),
                                              get_tokenizer_mode(), get_dtype())


model = build_app({})
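
The import_path in the manifest's serveConfigV2 (ray-operator.config.samples.vllm.serve_tpu:model) resolves to the model application bound on the last line above. A sketch of deploying it by hand for a quick check, bypassing the RayService controller, assuming a Ray cluster exposing 8 "TPU" resources is reachable and this module is on PYTHONPATH (note that importing it triggers the ray.init call inside build_app):

from ray import serve

import serve_tpu  # importing builds the app via build_app({})

serve.run(serve_tpu.model)  # deploys the app; the Serve HTTP proxy listens on port 8000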
