Skip to content

Commit

Permalink
[TPU] Support single and multi-host TPUs on GKE (#7613)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardsliu authored Aug 30, 2024
1 parent dc13e99 commit 2148441
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements-tpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
ray
ray[default]
5 changes: 4 additions & 1 deletion vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def __init__(
raise NotImplementedError("TPU version must be 4 or higher.")

self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
tpu_type = tpu_type.lower()

if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
Expand Down
27 changes: 25 additions & 2 deletions vllm/distributed/device_communicators/tpu_communicator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_tpu():
import ray
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt

from vllm.executor import ray_utils


class TpuCommunicator:

Expand All @@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup):
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())

# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg

local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size

# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)

pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()

Expand Down
15 changes: 15 additions & 0 deletions vllm/executor/ray_tpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"

# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

worker = ray.remote(
num_cpus=0,
resources={"TPU": 1},
Expand All @@ -81,6 +94,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
if override_env:
worker.override_env_vars.remote(override_env)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
Expand Down
29 changes: 29 additions & 0 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -84,6 +85,9 @@ def execute_model_spmd(

return output

def override_env_vars(self, vars: Dict[str, str]):
os.environ.update(vars)

ray_import_err = None

except ImportError as e:
Expand Down Expand Up @@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group


def get_num_tpu_nodes() -> int:
from ray._private.accelerators import TPUAcceleratorManager
cluster_resources = ray.cluster_resources()
total_tpus = int(cluster_resources["TPU"])
tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
assert total_tpus % tpus_per_node == 0
return total_tpus // tpus_per_node


def get_num_nodes_in_placement_group() -> int:
pg_table = ray.util.placement_group_table()
current_pg = ray.util.get_current_placement_group()
num_nodes = 0

if current_pg:
nodes_in_pg = set()
for pg_key, pg in pg_table.items():
if pg_key == current_pg.id.hex():
for _, node in pg["bundles_to_node_id"].items():
nodes_in_pg.add(node)
num_nodes = len(nodes_in_pg)

return num_nodes

0 comments on commit 2148441

Please sign in to comment.