Add vLLM TPU example RayService manifest (#3000)
* Add vLLM TPU example RayService
  Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
* Add comments
  Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
* Lint
  Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
---------
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 9559227, commit b227924
Showing 2 changed files with 255 additions and 0 deletions.
146 changes: 146 additions & 0 deletions
ray-operator/config/samples/vllm/ray-service.vllm-tpu-v6e-singlehost.yaml
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: vllm-tpu
spec:
  serveConfigV2: |
    applications:
      - name: llm
        import_path: ray-operator.config.samples.vllm.serve_tpu:model
        deployments:
          - name: VLLMDeployment
            num_replicas: 1
        runtime_env:
          working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
          env_vars:
            MODEL_ID: "$MODEL_ID"
            MAX_MODEL_LEN: "$MAX_MODEL_LEN"
            DTYPE: "$DTYPE"
            TOKENIZER_MODE: "$TOKENIZER_MODE"
            TPU_CHIPS: "8"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
            gke-gcsfuse/cpu-limit: "0"
            gke-gcsfuse/memory-limit: "0"
            gke-gcsfuse/ephemeral-storage-limit: "0"
        spec:
          # replace $KSA_NAME with your Kubernetes Service Account
          serviceAccountName: $KSA_NAME
          containers:
            - name: ray-head
              # replace $VLLM_IMAGE with your vLLM container image
              image: $VLLM_IMAGE
              imagePullPolicy: IfNotPresent
              ports:
                - containerPort: 6379
                  name: gcs
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                - containerPort: 8000
                  name: serve
              env:
                - name: HUGGING_FACE_HUB_TOKEN
                  valueFrom:
                    secretKeyRef:
                      name: hf-secret
                      key: hf_api_token
                - name: VLLM_XLA_CACHE_PATH
                  value: "/data"
              resources:
                limits:
                  cpu: "2"
                  memory: 8G
                requests:
                  cpu: "2"
                  memory: 8G
              volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
          volumes:
            - name: gke-gcsfuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  # replace $GSBUCKET with your GCS bucket
                  bucketName: $GSBUCKET
                  mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    workerGroupSpecs:
      - groupName: tpu-group
        replicas: 1
        minReplicas: 1
        maxReplicas: 1
        numOfHosts: 1
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            # replace $KSA_NAME with your Kubernetes Service Account
            serviceAccountName: $KSA_NAME
            containers:
              - name: ray-worker
                # replace $VLLM_IMAGE with your vLLM container image
                image: $VLLM_IMAGE
                imagePullPolicy: IfNotPresent
                resources:
                  limits:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                  requests:
                    cpu: "100"
                    google.com/tpu: "8"
                    ephemeral-storage: 40G
                    memory: 200G
                env:
                  - name: JAX_PLATFORMS
                    value: "tpu"
                  - name: HUGGING_FACE_HUB_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-secret
                        key: hf_api_token
                  - name: VLLM_XLA_CACHE_PATH
                    value: "/data"
                volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
            volumes:
              - name: gke-gcsfuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    # replace $GSBUCKET with your GCS bucket
                    bucketName: $GSBUCKET
                    mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
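
The manifest above uses shell-style placeholders ($MODEL_ID, $MAX_MODEL_LEN, $DTYPE, $TOKENIZER_MODE, $KSA_NAME, $VLLM_IMAGE, $GSBUCKET) that must be filled in before the RayService is applied. Below is a minimal sketch of one way to render them, assuming each placeholder is exported as an environment variable and the manifest is saved locally under the sample file name; the output file name is illustrative. Any templating approach (for example envsubst, Helm, or Kustomize) works equally well as long as every placeholder is replaced.

import os
from string import Template

# Placeholders used by ray-service.vllm-tpu-v6e-singlehost.yaml.
placeholders = ["MODEL_ID", "MAX_MODEL_LEN", "DTYPE", "TOKENIZER_MODE",
                "KSA_NAME", "VLLM_IMAGE", "GSBUCKET"]
values = {name: os.environ.get(name, "") for name in placeholders}

with open("ray-service.vllm-tpu-v6e-singlehost.yaml") as f:
    # safe_substitute leaves any unrecognized $ tokens untouched.
    rendered = Template(f.read()).safe_substitute(values)

# Write the rendered manifest, then apply it with `kubectl apply -f`.
with open("ray-service.rendered.yaml", "w") as f:
    f.write(rendered)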
109 changes: 109 additions & 0 deletions
ray-operator/config/samples/vllm/serve_tpu.py
import os

import json
import logging
from typing import Dict, List, Optional

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from starlette.responses import Response

from vllm import LLM, SamplingParams

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        model_id,
        num_tpu_chips,
        max_model_len,
        tokenizer_mode,
        dtype,
    ):
        self.llm = LLM(
            model=model_id,
            tensor_parallel_size=num_tpu_chips,
            max_model_len=max_model_len,
            dtype=dtype,
            download_dir=os.environ['VLLM_XLA_CACHE_PATH'],  # Error if not provided.
            tokenizer_mode=tokenizer_mode,
            enforce_eager=True,
        )

    @app.post("/v1/generate")
    async def generate(self, request: Request):
        request_dict = await request.json()
        prompts = request_dict.pop("prompt")
        max_toks = int(request_dict.pop("max_tokens"))
        print("Processing prompt ", prompts)
        sampling_params = SamplingParams(temperature=0.7,
                                         top_p=1.0,
                                         n=1,
                                         max_tokens=max_toks)

        outputs = self.llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = ""
            token_ids = []
            for completion_output in output.outputs:
                generated_text += completion_output.text
                token_ids.extend(list(completion_output.token_ids))

            print("Generated text: ", generated_text)
            ret = {
                "prompt": prompt,
                "text": generated_text,
                "token_ids": token_ids,
            }

        return Response(content=json.dumps(ret))


def get_num_tpu_chips() -> int:
    if "TPU" not in ray.cluster_resources():
        # Pass in TPU chips when the current Ray cluster resources can't be auto-detected (i.e. for autoscaling).
        if os.environ.get('TPU_CHIPS') is not None:
            return int(os.environ.get('TPU_CHIPS'))
        return 0
    return int(ray.cluster_resources()["TPU"])


def get_max_model_len() -> Optional[int]:
    if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
        return None
    return int(os.environ['MAX_MODEL_LEN'])


def get_tokenizer_mode() -> str:
    if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
        return "auto"
    return os.environ['TOKENIZER_MODE']


def get_dtype() -> str:
    if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
        return "auto"
    return os.environ['DTYPE']


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    ray.init(ignore_reinit_error=True, address="ray://localhost:10001")

    model_id = os.environ['MODEL_ID']

    num_tpu_chips = get_num_tpu_chips()
    pg_resources = []
    pg_resources.append({"CPU": 1})  # for the deployment replica
    for i in range(num_tpu_chips):
        pg_resources.append({"CPU": 1, "TPU": 1})  # for the vLLM actors

    # Use PACK strategy since the deployment may use more than one TPU node.
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="PACK").bind(model_id, num_tpu_chips,
                                              get_max_model_len(),
                                              get_tokenizer_mode(), get_dtype())


model = build_app({})
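
Once the RayService is healthy, the deployment serves the POST /v1/generate route defined above on the Serve HTTP port (8000 in the manifest). Below is a minimal client sketch using the requests library, assuming the Serve service has been forwarded to localhost:8000 (for example with kubectl port-forward; the exact service name depends on the RayService); the prompt and token budget are illustrative.

import requests

resp = requests.post(
    "http://localhost:8000/v1/generate",
    json={
        "prompt": "What are the top 5 most popular programming languages?",
        "max_tokens": 128,
    },
    timeout=300,
)
resp.raise_for_status()
result = resp.json()
print(result["text"])            # generated completion
print(len(result["token_ids"]))  # number of generated token IDs

The handler expects both "prompt" and "max_tokens" in the request body and returns the prompt, the generated text, and the token IDs as JSON.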