ENH: Some improvements for Xavier (#2777)
ChengjieLi28 authored Jan 24, 2025
1 parent 1d070e7 commit 121c08a
Showing 13 changed files with 700 additions and 253 deletions.
26 changes: 9 additions & 17 deletions doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Xinference \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-01-10 14:44+0800\n"
"POT-Creation-Date: 2025-01-23 14:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@@ -31,9 +31,10 @@ msgid ""
" instances. This allows KV cache computed by other replicas to be "
"directly reused, avoiding redundant computations."
msgstr ""
"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会影响整体吞吐量和单次推理的延迟。"
"Xinference 通过引入 ``Xavier`` 框架来增强 vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。"
"这使得其他副本计算出的 KV 缓存可以被直接重用,从而避免了冗余计算。"
"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会"
"影响整体吞吐量和单次推理的延迟。Xinference 通过引入 ``Xavier`` 框架来增强"
" vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。这使得其他副本计算出的 "
"KV 缓存可以被直接重用,从而避免了冗余计算。"

#: ../../source/user_guide/vllm_enhancement.rst:15
msgid "Usage"
@@ -43,31 +44,22 @@ msgstr "使用"
msgid ""
"Simply add the parameter ``enable_xavier=True`` when starting the vllm "
"model."
msgstr ""
"启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。"
msgstr "启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。"

#: ../../source/user_guide/vllm_enhancement.rst:20
msgid "Limitations"
msgstr "限制"

#: ../../source/user_guide/vllm_enhancement.rst:21
msgid "Xavier requires vllm version >= ``0.6.5``."
msgstr ""
"Xavier 要求 vllm 版本不低于 ``0.6.5`` 。"
msgstr "Xavier 要求 vllm 版本不低于 ``0.6.5`` 。"

#: ../../source/user_guide/vllm_enhancement.rst:22
msgid ""
"Xavier is currently not compatible with model reloading after CUDA OOM in"
" Xinference. (it will be supported in the future)"
msgstr ""
"目前 Xavier 与 Xinference 中模型 CUDA OOM 后的重新拉起特性不兼容(未来将解决此问题)。"

#: ../../source/user_guide/vllm_enhancement.rst:23
msgid ""
"Due to the underlying communication not recognizing ``0.0.0.0``, the "
"actual IP address needs to be passed when starting Xinference, for "
"example: ``xinference-local -H 192.168.xx.xx``."
msgstr ""
"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 IP 地址,"
"例如:``xinference-local -H 192.168.xx.xx`` 。"
"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 "
"IP 地址,例如:``xinference-local -H 192.168.xx.xx`` 。"

1 change: 0 additions & 1 deletion doc/source/user_guide/vllm_enhancement.rst
@@ -19,5 +19,4 @@ Simply add the parameter ``enable_xavier=True`` when starting the vllm model.
Limitations
***********
* Xavier requires vllm version >= ``0.6.5``.
* Xavier is currently not compatible with model reloading after CUDA OOM in Xinference. (it will be supported in the future)
* Due to the underlying communication not recognizing ``0.0.0.0``, the actual IP address needs to be passed when starting Xinference, for example: ``xinference-local -H 192.168.xx.xx``.
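
As a worked illustration of the documentation above, here is a minimal sketch of launching a vllm model with Xavier enabled from the Python client. The ``enable_xavier`` flag, the ``replica > 1`` requirement, and the real-IP requirement come from this commit; the endpoint URL and model name are illustrative assumptions.

    # Minimal sketch; endpoint and model name are illustrative, not from the commit.
    # The server must be started with a real IP (Xavier's communication layer
    # cannot resolve 0.0.0.0), e.g.: xinference-local -H 192.168.xx.xx
    from xinference.client import Client

    client = Client("http://192.168.xx.xx:9997")
    model_uid = client.launch_model(
        model_name="qwen2.5-instruct",  # illustrative; any vllm-served LLM
        model_engine="vllm",            # Xavier only applies to the vllm engine
        replica=2,                      # replica must be > 1, or Xavier is disabled
        enable_xavier=True,             # forwarded through the launch kwargs
    )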
2 changes: 2 additions & 0 deletions xinference/core/model.py
@@ -35,6 +35,7 @@
List,
Optional,
Union,
no_type_check,
)

import sse_starlette.sse
@@ -302,6 +303,7 @@ def __repr__(self) -> str:
def decrease_serve_count(self):
self._serve_count -= 1

@no_type_check
async def start_transfer_for_vllm(self, rank_addresses: List[str]):
from ..model.llm.vllm.core import VLLMModel
from ..model.llm.vllm.xavier.transfer import TransferActor
137 changes: 114 additions & 23 deletions xinference/core/supervisor.py
@@ -268,8 +268,12 @@ async def signal_handler():
)

from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager

self._block_tracker: Optional[xo.ActorRefType[VLLMBlockTracker]] = None
self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}
self._collective_manager_mapping: Dict[
str, xo.ActorRefType[CollectiveManager]
] = {}

@typing.no_type_check
async def get_cluster_device_info(self, detailed: bool = False) -> List:
@@ -960,26 +964,40 @@ async def launch_builtin_model(
]:
raise ValueError("Tensorizer is not supported for %s." % model_name)

if model_uid is None:
model_uid = self._gen_model_uid(model_name)

# Xavier-related
enable_xavier: bool = (
bool(kwargs.pop("enable_xavier", False))
and model_engine is not None
and model_engine.lower() == "vllm"
)
store_address = None
store_port = None
world_size = None
if enable_xavier:
if replica <= 1:
logger.warning(f"Enabling xavier when `replica<=1` is meaningless.")
enable_xavier = False
else:
from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager

self._block_tracker = await xo.create_actor(
self._block_tracker_mapping[model_uid] = await xo.create_actor(
VLLMBlockTracker,
address=self.address,
uid=VLLMBlockTracker.default_uid(),
uid=f"{VLLMBlockTracker.default_uid()}-{model_uid}",
)

if model_uid is None:
model_uid = self._gen_model_uid(model_name)
world_size = replica + 1
logger.info(f"Going to start xavier with world size: {world_size}")
self._collective_manager_mapping[model_uid] = await xo.create_actor(
CollectiveManager,
address=self.address,
uid=f"{CollectiveManager.default_uid()}-{model_uid}",
model_uid=model_uid,
)
logger.info(f"Start collective manager for {model_uid} done.")

model_size = str(model_size_in_billions) if model_size_in_billions else ""
logger.debug(
@@ -988,13 +1006,38 @@
f"kwargs: {kwargs}"
)

async def _launch_one_model(
worker_ref, _replica_model_uid, rank: int, store_port: int
):
async def _launch_one_model(worker_ref, _replica_model_uid, rank: int):
if _replica_model_uid in self._replica_model_uid_to_worker:
raise ValueError(
f"Model is already in the model list, uid: {_replica_model_uid}"
)

nonlocal store_address
nonlocal store_port
xavier_config = (
{
"block_tracker_uid": self._block_tracker_mapping[model_uid].uid,
"block_tracker_address": self._block_tracker_mapping[
model_uid
].address,
"rank": rank,
"world_size": world_size,
"store_address": store_address,
"store_port": store_port,
}
if enable_xavier
else None
)

if enable_xavier and rank == 0:
rank0_address, _port = await worker_ref.launch_rank0_model(
_replica_model_uid, xavier_config
)
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
store_address = rank0_address.split(":")[0]
store_port = _port
return rank0_address

replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
nonlocal model_type

@@ -1014,37 +1057,36 @@ async def _launch_one_model(
gpu_idx=replica_gpu_idx,
download_hub=download_hub,
model_path=model_path,
xavier_config={
"block_tracker_address": self._block_tracker.address
if self._block_tracker is not None
else None,
"rank": rank,
"world_size": replica,
"store_address": self.address.split(":")[0],
"store_port": store_port,
}
if enable_xavier
else None,
xavier_config=xavier_config,
**kwargs,
)
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
return subpool_address

async def _launch_model():
try:
store_port = xo.utils.get_next_port()
worker_refs = []
rank_addresses = []
for rank, rep_model_uid in enumerate(
for _idx, rep_model_uid in enumerate(
iter_replica_model_uid(model_uid, replica)
):
worker_ref = (
target_ip_worker_ref
if target_ip_worker_ref is not None
else await self._choose_worker()
)
if enable_xavier and _idx == 0:
"""
Start the rank 0 model actor on the worker that holds the rank 1 replica,
solely for constructing the collective communication world.
"""
_uid = model_uid + "-rank0"
rank0_address = await _launch_one_model(worker_ref, _uid, 0)
worker_refs.append((worker_ref, _uid))
rank_addresses.append(rank0_address)

subpool_address = await _launch_one_model(
worker_ref, rep_model_uid, rank, store_port
worker_ref, rep_model_uid, _idx + 1
)
worker_refs.append((worker_ref, rep_model_uid))
rank_addresses.append(subpool_address)
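
To make the rank numbering concrete: the dedicated ``-rank0`` actor is launched once before the loop solely for collective communication, and each serving replica is launched with rank ``_idx + 1``, which exactly fills the ``replica + 1`` slots of the world. A small illustrative sketch (values are examples, not from the commit):

    # Illustrative rank layout for a 2-replica model.
    replica = 2
    world_size = replica + 1                    # 3: one coordinator + two replicas
    coordinator_rank = 0                        # the "<model_uid>-rank0" actor
    replica_ranks = list(range(1, world_size))  # [1, 2]: passed as `_idx + 1`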
@@ -1054,6 +1096,7 @@ async def _launch_model():
# because the transfer actor needs all the rank addresses used for collective communication
if enable_xavier:
logger.debug(f"Init transfer component for xavier...")
collective_manager_ref = self._collective_manager_mapping[model_uid]
tasks = []
for worker_ref, rep_model_uid in worker_refs:
tasks.append(
@@ -1064,6 +1107,13 @@
# Here you must use asyncio.gather, not a for loop,
# or you will get stuck.
await asyncio.gather(*tasks)

# init collective_manager
for idx, addr in enumerate(rank_addresses):
await collective_manager_ref.register_rank(
idx, addr, update=False
)

logger.debug(f"Init transfer component for xavier done.")
except Exception:
# terminate_model will remove the replica info.
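
The ``asyncio.gather`` comment above encodes a real ordering constraint: if each ``start_transfer_for_vllm`` call only returns once every rank has joined the collective, awaiting the calls one at a time deadlocks, because the first awaited rank waits on peers that have not been asked to join yet. A minimal illustrative sketch (the block-until-all-join behavior is inferred, not shown in this diff):

    import asyncio

    async def init_transfer(model_refs, rank_addresses):
        # Correct: every rank enters the rendezvous concurrently, so it completes.
        await asyncio.gather(
            *(ref.start_transfer_for_vllm(rank_addresses) for ref in model_refs)
        )
        # Hangs: the first awaited rank never returns, since its peers
        # have not started joining yet.
        # for ref in model_refs:
        #     await ref.start_transfer_for_vllm(rank_addresses)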
@@ -1193,6 +1243,38 @@ async def _terminate_one_model(_replica_model_uid):
raise
self._model_uid_to_replica_info.pop(model_uid, None)

# clear for xavier
rank0_uid = model_uid + "-rank0"
if rank0_uid in self._replica_model_uid_to_worker:
await _terminate_one_model(rank0_uid)

collective_manager_ref = self._collective_manager_mapping.pop(model_uid, None)
if collective_manager_ref is not None:
try:
await xo.destroy_actor(collective_manager_ref)
except Exception as e:
logger.debug(
"Destroy collective_manager_ref failed, model uid: %s, error: %s",
model_uid,
e,
)
finally:
logger.debug(
f"Destroy collective_manager_ref done. model uid: {model_uid}"
)
block_tracker_ref = self._block_tracker_mapping.pop(model_uid, None)
if block_tracker_ref is not None:
try:
await xo.destroy_actor(block_tracker_ref)
except Exception as e:
logger.debug(
"Destroy block_tracker_ref failed, model uid: %s, error: %s",
model_uid,
e,
)
finally:
logger.debug(f"Destroy block_tracker_ref done. model uid: {model_uid}")

@log_async(logger=logger)
async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
replica_info = self._model_uid_to_replica_info.get(model_uid, None)
@@ -1448,3 +1530,12 @@ def record_metrics(name, op, kwargs):

async def get_progress(self, request_id: str) -> float:
return await self._progress_tracker.get_progress(request_id)

async def call_collective_manager(
self, model_uid: str, func_name: str, *args, **kwargs
):
"""
Used by worker.
"""
collective_manager_ref = self._collective_manager_mapping[model_uid]
await getattr(collective_manager_ref, func_name)(*args, **kwargs)
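
The new ``call_collective_manager`` entry point lets a worker reach the per-model ``CollectiveManager`` without holding its actor ref, dispatching by method name via ``getattr``. A hedged usage sketch from the worker side — ``register_rank`` and its ``update`` flag appear in this commit, but calling it with ``update=True`` (e.g. after a replica is relaunched) is an assumed scenario:

    # Hedged sketch: worker-side re-registration of a relaunched rank.
    async def reregister_rank(supervisor_ref, model_uid: str, rank: int, rank_address: str):
        # `register_rank` comes from this commit; the update=True usage is assumed.
        await supervisor_ref.call_collective_manager(
            model_uid, "register_rank", rank, rank_address, update=True
        )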
