diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po b/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po index 09948143c0..ad02d3a579 100644 --- a/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po +++ b/doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: Xinference \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-01-10 14:44+0800\n" +"POT-Creation-Date: 2025-01-23 14:46+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Language-Team: LANGUAGE <LL@li.org>\n" @@ -31,9 +31,10 @@ msgid "" " instances. This allows KV cache computed by other replicas to be " "directly reused, avoiding redundant computations." msgstr "" -"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会影响整体吞吐量和单次推理的延迟。" -"Xinference 通过引入 ``Xavier`` 框架来增强 vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。" -"这使得其他副本计算出的 KV 缓存可以被直接重用,从而避免了冗余计算。" +"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会" +"影响整体吞吐量和单次推理的延迟。Xinference 通过引入 ``Xavier`` 框架来增强" +" vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。这使得其他副本计算出的 " +"KV 缓存可以被直接重用,从而避免了冗余计算。" #: ../../source/user_guide/vllm_enhancement.rst:15 msgid "Usage" @@ -43,8 +44,7 @@ msgstr "使用" msgid "" "Simply add the parameter ``enable_xavier=True`` when starting the vllm " "model." -msgstr "" -"启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。" +msgstr "启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。" #: ../../source/user_guide/vllm_enhancement.rst:20 msgid "Limitations" @@ -52,22 +52,14 @@ msgstr "限制" #: ../../source/user_guide/vllm_enhancement.rst:21 msgid "Xavier requires vllm version >= ``0.6.5``." -msgstr "" -"Xavier 要求 vllm 版本不低于 ``0.6.5`` 。" +msgstr "Xavier 要求 vllm 版本不低于 ``0.6.5`` 。" #: ../../source/user_guide/vllm_enhancement.rst:22 msgid "" -"Xavier is currently not compatible with model reloading after CUDA OOM in" -" Xinference. (it will be supported in the future)" -msgstr "" -"目前 Xavier 与 Xinference 中模型 CUDA OOM 后的重新拉起特性不兼容(未来将解决此问题)。" - -#: ../../source/user_guide/vllm_enhancement.rst:23 -msgid "" "Due to the underlying communication not recognizing ``0.0.0.0``, the " "actual IP address needs to be passed when starting Xinference, for " "example: ``xinference-local -H 192.168.xx.xx``." msgstr "" -"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 IP 地址," -"例如:``xinference-local -H 192.168.xx.xx`` 。" +"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 " +"IP 地址,例如:``xinference-local -H 192.168.xx.xx`` 。" diff --git a/doc/source/user_guide/vllm_enhancement.rst b/doc/source/user_guide/vllm_enhancement.rst index 1446b7c9d4..e175449fd7 100644 --- a/doc/source/user_guide/vllm_enhancement.rst +++ b/doc/source/user_guide/vllm_enhancement.rst @@ -19,5 +19,4 @@ Simply add the parameter ``enable_xavier=True`` when starting the vllm model. Limitations *********** * Xavier requires vllm version >= ``0.6.5``. -* Xavier is currently not compatible with model reloading after CUDA OOM in Xinference. (it will be supported in the future) * Due to the underlying communication not recognizing ``0.0.0.0``, the actual IP address needs to be passed when starting Xinference, for example: ``xinference-local -H 192.168.xx.xx``. 
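The documentation change above boils down to a single launch-time flag. As a minimal, hedged sketch (not part of this patch): the model name, size, and addresses below are placeholders, it assumes the Python RESTful client API, and the endpoint must have been started with a real IP (e.g. `xinference-local -H 192.168.1.10`) because of the `0.0.0.0` limitation noted above. `enable_xavier=True` is forwarded as a plain launch kwarg and only takes effect when `replica > 1`.

```python
# Hypothetical launch example for Xavier; model name, address and size are placeholders.
from xinference.client import Client

client = Client("http://192.168.1.10:9997")  # endpoint started with a real IP, not 0.0.0.0

model_uid = client.launch_model(
    model_name="qwen2.5-instruct",   # placeholder model
    model_engine="vllm",             # Xavier only applies to the vLLM engine (vllm >= 0.6.5)
    model_size_in_billions=7,
    replica=2,                       # replica must be > 1, otherwise enable_xavier is ignored with a warning
    enable_xavier=True,              # share computed KV cache blocks across the replicas
)
model = client.get_model(model_uid)
```

Note that, per the supervisor changes below, the collective world size becomes `replica + 1`: an extra rank-0 actor is launched solely to host the TCP store and anchor the collective communication world.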
diff --git a/xinference/core/model.py b/xinference/core/model.py index 2c846081b4..db30bd4645 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -35,6 +35,7 @@ List, Optional, Union, + no_type_check, ) import sse_starlette.sse @@ -302,6 +303,7 @@ def __repr__(self) -> str: def decrease_serve_count(self): self._serve_count -= 1 + @no_type_check async def start_transfer_for_vllm(self, rank_addresses: List[str]): from ..model.llm.vllm.core import VLLMModel from ..model.llm.vllm.xavier.transfer import TransferActor diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 5235b2a15d..b890f46f88 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -268,8 +268,12 @@ async def signal_handler(): ) from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker + from ..model.llm.vllm.xavier.collective_manager import CollectiveManager - self._block_tracker: Optional[xo.ActorRefType[VLLMBlockTracker]] = None + self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {} + self._collective_manager_mapping: Dict[ + str, xo.ActorRefType[CollectiveManager] + ] = {} @typing.no_type_check async def get_cluster_device_info(self, detailed: bool = False) -> List: @@ -960,26 +964,40 @@ async def launch_builtin_model( ]: raise ValueError("Tensorizer is not supported for %s." % model_name) + if model_uid is None: + model_uid = self._gen_model_uid(model_name) + + # Xavier-related enable_xavier: bool = ( bool(kwargs.pop("enable_xavier", False)) and model_engine is not None and model_engine.lower() == "vllm" ) + store_address = None + store_port = None + world_size = None if enable_xavier: if replica <= 1: logger.warning(f"Enabling xavier when `replica<=1` is meaningless.") enable_xavier = False else: from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker + from ..model.llm.vllm.xavier.collective_manager import CollectiveManager - self._block_tracker = await xo.create_actor( + self._block_tracker_mapping[model_uid] = await xo.create_actor( VLLMBlockTracker, address=self.address, - uid=VLLMBlockTracker.default_uid(), + uid=f"{VLLMBlockTracker.default_uid()}-{model_uid}", ) - - if model_uid is None: - model_uid = self._gen_model_uid(model_name) + world_size = replica + 1 + logger.info(f"Going to start xavier with world size: {world_size}") + self._collective_manager_mapping[model_uid] = await xo.create_actor( + CollectiveManager, + address=self.address, + uid=f"{CollectiveManager.default_uid()}-{model_uid}", + model_uid=model_uid, + ) + logger.info(f"Start collective manager for {model_uid} done.") model_size = str(model_size_in_billions) if model_size_in_billions else "" logger.debug( @@ -988,13 +1006,38 @@ async def launch_builtin_model( f"kwargs: {kwargs}" ) - async def _launch_one_model( - worker_ref, _replica_model_uid, rank: int, store_port: int - ): + async def _launch_one_model(worker_ref, _replica_model_uid, rank: int): if _replica_model_uid in self._replica_model_uid_to_worker: raise ValueError( f"Model is already in the model list, uid: {_replica_model_uid}" ) + + nonlocal store_address + nonlocal store_port + xavier_config = ( + { + "block_tracker_uid": self._block_tracker_mapping[model_uid].uid, + "block_tracker_address": self._block_tracker_mapping[ + model_uid + ].address, + "rank": rank, + "world_size": world_size, + "store_address": store_address, + "store_port": store_port, + } + if enable_xavier + else None + ) + + if enable_xavier and rank == 0: + rank0_address, _port = await 
worker_ref.launch_rank0_model( + _replica_model_uid, xavier_config + ) + self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref + store_address = rank0_address.split(":")[0] + store_port = _port + return rank0_address + replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx) nonlocal model_type @@ -1014,17 +1057,7 @@ async def _launch_one_model( gpu_idx=replica_gpu_idx, download_hub=download_hub, model_path=model_path, - xavier_config={ - "block_tracker_address": self._block_tracker.address - if self._block_tracker is not None - else None, - "rank": rank, - "world_size": replica, - "store_address": self.address.split(":")[0], - "store_port": store_port, - } - if enable_xavier - else None, + xavier_config=xavier_config, **kwargs, ) self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref @@ -1032,10 +1065,9 @@ async def _launch_one_model( async def _launch_model(): try: - store_port = xo.utils.get_next_port() worker_refs = [] rank_addresses = [] - for rank, rep_model_uid in enumerate( + for _idx, rep_model_uid in enumerate( iter_replica_model_uid(model_uid, replica) ): worker_ref = ( @@ -1043,8 +1075,18 @@ async def _launch_model(): if target_ip_worker_ref is not None else await self._choose_worker() ) + if enable_xavier and _idx == 0: + """ + Start the rank 0 model actor on the worker that holds the rank 1 replica, + solely for constructing the collective communication world. + """ + _uid = model_uid + "-rank0" + rank0_address = await _launch_one_model(worker_ref, _uid, 0) + worker_refs.append((worker_ref, _uid)) + rank_addresses.append(rank0_address) + subpool_address = await _launch_one_model( - worker_ref, rep_model_uid, rank, store_port + worker_ref, rep_model_uid, _idx + 1 ) worker_refs.append((worker_ref, rep_model_uid)) rank_addresses.append(subpool_address) @@ -1054,6 +1096,7 @@ async def _launch_model(): # because the transfer actor needs all the rank addresses used for collective communication if enable_xavier: logger.debug(f"Init transfer component for xavier...") + collective_manager_ref = self._collective_manager_mapping[model_uid] tasks = [] for worker_ref, rep_model_uid in worker_refs: tasks.append( @@ -1064,6 +1107,13 @@ async def _launch_model(): # Here you must use asyncio.gather, not a for loop, # or you will get stuck. await asyncio.gather(*tasks) + + # init collective_manager + for idx, addr in enumerate(rank_addresses): + await collective_manager_ref.register_rank( + idx, addr, update=False + ) + logger.debug(f"Init transfer component for xavier done.") except Exception: # terminate_model will remove the replica info. @@ -1193,6 +1243,38 @@ async def _terminate_one_model(_replica_model_uid): raise self._model_uid_to_replica_info.pop(model_uid, None) + # clear for xavier + rank0_uid = model_uid + "-rank0" + if rank0_uid in self._replica_model_uid_to_worker: + await _terminate_one_model(rank0_uid) + + collective_manager_ref = self._collective_manager_mapping.pop(model_uid, None) + if collective_manager_ref is not None: + try: + await xo.destroy_actor(collective_manager_ref) + except Exception as e: + logger.debug( + "Destroy collective_manager_ref failed, model uid: %s, error: %s", + model_uid, + e, + ) + finally: + logger.debug( + f"Destroy collective_manager_ref done. 
model uid: {model_uid}" + ) + block_tracker_ref = self._block_tracker_mapping.pop(model_uid, None) + if block_tracker_ref is not None: + try: + await xo.destroy_actor(block_tracker_ref) + except Exception as e: + logger.debug( + "Destroy block_tracker_ref failed, model uid: %s, error: %s", + model_uid, + e, + ) + finally: + logger.debug(f"Destroy block_tracker_ref done. model uid: {model_uid}") + @log_async(logger=logger) async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]: replica_info = self._model_uid_to_replica_info.get(model_uid, None) @@ -1448,3 +1530,12 @@ def record_metrics(name, op, kwargs): async def get_progress(self, request_id: str) -> float: return await self._progress_tracker.get_progress(request_id) + + async def call_collective_manager( + self, model_uid: str, func_name: str, *args, **kwargs + ): + """ + Used by worker. + """ + collective_manager_ref = self._collective_manager_mapping[model_uid] + await getattr(collective_manager_ref, func_name)(*args, **kwargs) diff --git a/xinference/core/worker.py b/xinference/core/worker.py index d0a5f5ffe6..f50a8276fa 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -24,7 +24,7 @@ from collections import defaultdict from dataclasses import dataclass from logging import getLogger -from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, no_type_check import xoscar as xo from async_timeout import timeout @@ -184,12 +184,12 @@ async def recover_sub_pool(self, address): self._model_uid_to_recover_count[model_uid] = ( recover_count - 1 ) - await self.launch_builtin_model(**launch_args) + await self.recover_model(launch_args) else: logger.warning("Stop recreating model actor.") else: logger.warning("Recreating model actor %s ...", model_uid) - await self.launch_builtin_model(**launch_args) + await self.recover_model(launch_args) break @classmethod @@ -940,7 +940,11 @@ async def terminate_model(self, model_uid: str, is_model_die=False): # Terminate model while its launching is not allow if model_uid in self._model_uid_launching_guard: raise ValueError(f"{model_uid} is launching") - origin_uid, _ = parse_replica_model_uid(model_uid) + # In special cases, if the suffix is `-rank0`, this is the Xavier's rank 0 model actor. + if model_uid.endswith("-rank0"): + origin_uid = model_uid.removesuffix("-rank0") + else: + origin_uid, _ = parse_replica_model_uid(model_uid) try: _ = await self.get_supervisor_ref() if self._event_collector_ref is not None: @@ -1173,3 +1177,65 @@ async def start_transfer_for_vllm( ): model_ref = self._model_uid_to_model[rep_model_uid] await model_ref.start_transfer_for_vllm(rank_addresses) + + @log_async(logger=logger, level=logging.INFO) + async def launch_rank0_model( + self, rep_model_uid: str, xavier_config: Dict[str, Any] + ) -> Tuple[str, int]: + from ..model.llm.vllm.xavier.collective_manager import Rank0ModelActor + + if os.name != "nt" and platform.system() != "Darwin": + # Linux + start_method = "forkserver" + else: + # Windows and macOS + start_method = "spawn" + subpool_address = await self._main_pool.append_sub_pool( + start_method=start_method + ) + + store_address = subpool_address.split(":")[0] + # Note that `store_port` needs to be generated on the worker, + # as the TCP store is on rank 0, not on the supervisor. 
+ store_port = xo.utils.get_next_port() + self._model_uid_launching_guard[rep_model_uid] = True + try: + try: + xavier_config["rank_address"] = subpool_address + xavier_config["store_address"] = store_address + xavier_config["store_port"] = store_port + model_ref = await xo.create_actor( + Rank0ModelActor, + address=subpool_address, + uid=rep_model_uid, + xavier_config=xavier_config, + ) + except: + await self._main_pool.remove_sub_pool(subpool_address) + raise + self._model_uid_to_model[rep_model_uid] = model_ref + self._model_uid_to_addr[rep_model_uid] = subpool_address + finally: + del self._model_uid_launching_guard[rep_model_uid] + return subpool_address, store_port + + @no_type_check + async def recover_model(self, launch_args: Dict[str, Any]): + rep_model_uid = launch_args.get("model_uid") + origin_uid, _ = parse_replica_model_uid(rep_model_uid) + xavier_config: Optional[Dict[str, Any]] = launch_args.get("xavier_config", None) + is_xavier: bool = xavier_config is not None + supervisor_ref = await self.get_supervisor_ref(add_worker=False) + if is_xavier: + rank = xavier_config.get("rank") + await supervisor_ref.call_collective_manager( + origin_uid, "unregister_rank", rank + ) + subpool_address = await self.launch_builtin_model(**launch_args) + if is_xavier: + model_ref = self._model_uid_to_model[rep_model_uid] + await model_ref.start_transfer_for_vllm([]) + rank = xavier_config.get("rank") + await supervisor_ref.call_collective_manager( + origin_uid, "register_rank", rank, subpool_address, update=True + ) diff --git a/xinference/model/llm/vllm/xavier/block.py b/xinference/model/llm/vllm/xavier/block.py index 8aa2c991f5..1b43d83ae7 100644 --- a/xinference/model/llm/vllm/xavier/block.py +++ b/xinference/model/llm/vllm/xavier/block.py @@ -76,12 +76,11 @@ def xavier_config(self, v: Dict[str, Any]): self._xavier_config = v async def _get_block_tracker_ref(self): - from .block_tracker import VLLMBlockTracker - if self._block_tracker_ref is None: block_tracker_address = self.xavier_config.get("block_tracker_address") + block_tracker_uid = self.xavier_config.get("block_tracker_uid") self._block_tracker_ref = await xo.actor_ref( - address=block_tracker_address, uid=VLLMBlockTracker.default_uid() + address=block_tracker_address, uid=block_tracker_uid ) return self._block_tracker_ref @@ -90,7 +89,7 @@ async def unregister_block(self, block_id: int): tracker_ref = await self._get_block_tracker_ref() await tracker_ref.unregister_block( self.xavier_config.get("virtual_engine"), - self.xavier_config.get("rank_address"), + self.xavier_config.get("rank"), block_id, ) diff --git a/xinference/model/llm/vllm/xavier/block_tracker.py b/xinference/model/llm/vllm/xavier/block_tracker.py index dc50ca1b6a..a5789b4a46 100644 --- a/xinference/model/llm/vllm/xavier/block_tracker.py +++ b/xinference/model/llm/vllm/xavier/block_tracker.py @@ -24,81 +24,75 @@ def default_uid(cls): def __init__(self): super().__init__() - # engine -> hash_to_address_and_block_id - self._hash_to_address_and_block_id: Dict[ - int, Dict[int, Set[Tuple[str, int]]] - ] = {} - # engine -> address_to_hash_and_block_id - self._address_to_hash_and_block_id: Dict[ - int, Dict[str, Set[Tuple[int, int]]] - ] = {} + # engine -> hash -> (rank, block_id) + self._hash_to_rank_and_block_id: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {} + # engine -> rank -> (hash, block_id) + self._rank_to_hash_and_block_id: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {} + self._unavailable_ranks: Set[int] = set() def register_blocks( - self, virtual_engine: 
int, block_infos: List[Tuple[int, int]], address: str + self, virtual_engine: int, block_infos: List[Tuple[int, int]], rank: int ): # Update query meta - if virtual_engine not in self._hash_to_address_and_block_id: - self._hash_to_address_and_block_id[virtual_engine] = {} - hash_to_address_and_block_id = self._hash_to_address_and_block_id[ - virtual_engine - ] + if virtual_engine not in self._hash_to_rank_and_block_id: + self._hash_to_rank_and_block_id[virtual_engine] = {} + hash_to_rank_and_block_id = self._hash_to_rank_and_block_id[virtual_engine] for hash_content, block_id in block_infos: - if hash_content not in hash_to_address_and_block_id: - hash_to_address_and_block_id[hash_content] = { - (address, block_id), + if hash_content not in hash_to_rank_and_block_id: + hash_to_rank_and_block_id[hash_content] = { + (rank, block_id), } else: - hash_to_address_and_block_id[hash_content].add((address, block_id)) + hash_to_rank_and_block_id[hash_content].add((rank, block_id)) # Update remove meta - if virtual_engine not in self._address_to_hash_and_block_id: - self._address_to_hash_and_block_id[virtual_engine] = {} - address_to_hash_and_block_id = self._address_to_hash_and_block_id[ - virtual_engine - ] - if address not in address_to_hash_and_block_id: - address_to_hash_and_block_id[address] = set() - address_to_hash_and_block_id[address].update(block_infos) + if virtual_engine not in self._rank_to_hash_and_block_id: + self._rank_to_hash_and_block_id[virtual_engine] = {} + rank_to_hash_and_block_id = self._rank_to_hash_and_block_id[virtual_engine] + if rank not in rank_to_hash_and_block_id: + rank_to_hash_and_block_id[rank] = set() + rank_to_hash_and_block_id[rank].update(block_infos) def query_blocks( self, virtual_engine: int, hash_contents: List[Tuple[int, int]] - ) -> Dict[str, Set[Tuple[int, int, int]]]: - if virtual_engine not in self._hash_to_address_and_block_id: + ) -> Dict[int, Set[Tuple[int, int, int]]]: + if virtual_engine not in self._hash_to_rank_and_block_id: return {} - hash_to_address_and_block_id = self._hash_to_address_and_block_id[ - virtual_engine - ] - remote: Dict[str, Set[Tuple[int, int, int]]] = {} + hash_to_rank_and_block_id = self._hash_to_rank_and_block_id[virtual_engine] + remote: Dict[int, Set[Tuple[int, int, int]]] = {} for hash_content, _id in hash_contents: if ( - hash_content in hash_to_address_and_block_id - ) and hash_to_address_and_block_id[hash_content]: - # TODO: Randomly select here, and try to distribute requests as evenly as possible. - # There may be better methods in the future. - address, block_id = random.choice( - list(hash_to_address_and_block_id[hash_content]) - ) - if address not in remote: - remote[address] = { - (hash_content, block_id, _id), - } - else: - remote[address].add((hash_content, block_id, _id)) + hash_content in hash_to_rank_and_block_id + ) and hash_to_rank_and_block_id[hash_content]: + # exclude ranks that are in the recovery process + rank_and_block_id = [ + (r, b) + for r, b in hash_to_rank_and_block_id[hash_content] + if r not in self._unavailable_ranks + ] + if rank_and_block_id: + # TODO: Randomly select here, and try to distribute requests as evenly as possible. + # There may be better methods in the future. 
+ rank, block_id = random.choice(rank_and_block_id) + if rank not in remote: + remote[rank] = { + (hash_content, block_id, _id), + } + else: + remote[rank].add((hash_content, block_id, _id)) return remote - def unregister_block(self, virtual_engine: int, address: str, block_id: int): - if (virtual_engine not in self._address_to_hash_and_block_id) or ( - virtual_engine not in self._hash_to_address_and_block_id + def unregister_block(self, virtual_engine: int, rank: int, block_id: int): + if (virtual_engine not in self._rank_to_hash_and_block_id) or ( + virtual_engine not in self._hash_to_rank_and_block_id ): return # Update remove meta - address_to_hash_and_block_id = self._address_to_hash_and_block_id[ - virtual_engine - ] - if address not in address_to_hash_and_block_id: + rank_to_hash_and_block_id = self._rank_to_hash_and_block_id[virtual_engine] + if rank not in rank_to_hash_and_block_id: return - hash_and_block_id = address_to_hash_and_block_id[address] + hash_and_block_id = rank_to_hash_and_block_id[rank] detail: Optional[Tuple[int, int]] = None for hash_content, _id in hash_and_block_id.copy(): if _id == block_id: @@ -108,9 +102,28 @@ def unregister_block(self, virtual_engine: int, address: str, block_id: int): # Update query meta if detail is not None: - hash_to_address_and_block_id = self._hash_to_address_and_block_id[ - virtual_engine - ] + hash_to_rank_and_block_id = self._hash_to_rank_and_block_id[virtual_engine] _hash = detail[0] - if _hash in hash_to_address_and_block_id: - hash_to_address_and_block_id[_hash].discard((address, detail[1])) + if _hash in hash_to_rank_and_block_id: + hash_to_rank_and_block_id[_hash].discard((rank, detail[1])) + + def unregister_rank(self, rank: int): + """ + This rank is in the recovery process, and its query results will be excluded. + """ + self._unavailable_ranks.add(rank) + + def register_rank(self, rank: int): + """ + After recovery is successful, clear all stale data of the rank and mark the rank as available. + """ + for _, rank_to_hash_and_block_id in self._rank_to_hash_and_block_id.items(): + rank_to_hash_and_block_id.pop(rank, None) + + for _, hash_to_rank_and_block_id in self._hash_to_rank_and_block_id.items(): + for _, rank_and_block_id in hash_to_rank_and_block_id.items(): + to_delete = [(r, b) for r, b in rank_and_block_id if r == rank] + if to_delete: + rank_and_block_id.difference_update(to_delete) + + self._unavailable_ranks.discard(rank) diff --git a/xinference/model/llm/vllm/xavier/collective.py b/xinference/model/llm/vllm/xavier/collective.py new file mode 100644 index 0000000000..75b4c4df80 --- /dev/null +++ b/xinference/model/llm/vllm/xavier/collective.py @@ -0,0 +1,74 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class CollectiveRank: + def __init__( + self, + rank: int, + world_size: int, + rank_address: str, + store_address: str, + store_port: int, + world_addresses: List[str], + ): + self._rank = rank + self._world_size = world_size + self._rank_address = rank_address + self._world_addresses = world_addresses + self._store_address = store_address + self._store_port = store_port + self._device = None + self._tcp_store = None + self._context = None + + def init_rank(self): + from xoscar.collective import xoscar_pygloo as xp + + self._context = xp.rendezvous.Context(self._rank, self._world_size) + + attr = xp.transport.tcp.attr(self._rank_address.split(":")[0]) + self._device = xp.transport.tcp.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = self._store_port + opt.numWorkers = self._world_size + opt.isServer = self._rank == 0 + opt.waitWorkers = False + + self._tcp_store = xp.rendezvous.TCPStore(self._store_address, opt) + if self._world_addresses: + self.connect_full_mesh() + + def connect_full_mesh( + self, prefix: Optional[str] = None, world_addresses: Optional[List[str]] = None + ): + from xoscar.collective import xoscar_pygloo as xp + + assert self._device is not None + assert self._tcp_store is not None + assert self._context is not None + if world_addresses is not None: + self._world_addresses = world_addresses + prefix_store = xp.rendezvous.PrefixStore( + prefix or str(self._world_size), self._tcp_store + ) + self._context.connectFullMesh(prefix_store, self._device) + logger.debug( + f"Rank {self._rank} arrives successfully, world addresses: {self._world_addresses}" + ) diff --git a/xinference/model/llm/vllm/xavier/collective_manager.py b/xinference/model/llm/vllm/xavier/collective_manager.py new file mode 100644 index 0000000000..a0be319fe4 --- /dev/null +++ b/xinference/model/llm/vllm/xavier/collective_manager.py @@ -0,0 +1,147 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import asyncio +import logging +import traceback +from typing import TYPE_CHECKING, Any, Dict, List, Optional, no_type_check + +import xoscar as xo + +from .block_tracker import VLLMBlockTracker + +if TYPE_CHECKING: + from .transfer import Rank0TransferActor, TransferActor + + +logger = logging.getLogger(__name__) + + +class Rank0ModelActor(xo.StatelessActor): + @classmethod + def default_uid(cls): + return "rank0-model-actor" + + def __init__(self, xavier_config: Dict[str, Any]): + super().__init__() + self._rank = 0 + self._xavier_config = xavier_config + self._transfer_ref: Optional[xo.ActorRefType["Rank0TransferActor"]] = None + + async def __pre_destroy__(self): + if self._transfer_ref is not None: + try: + await xo.destroy_actor(self._transfer_ref) + del self._transfer_ref + except Exception as e: + logger.debug( + f"Destroy transfer actor failed, rank: {self._rank}, address: {self.address}, error: {e}" + ) + + @no_type_check + async def start_transfer_for_vllm(self, rank_addresses: List[str]): + from .transfer import Rank0TransferActor + + self._transfer_ref = await xo.create_actor( + Rank0TransferActor, + address=self.address, + uid=f"{Rank0TransferActor.default_uid()}-{self._rank}", + rank=self._rank, + world_size=self._xavier_config.get("world_size"), # type: ignore + rank_address=self._xavier_config.get("rank_address"), # type: ignore + store_address=self._xavier_config.get("store_address"), # type: ignore + store_port=self._xavier_config.get("store_port"), # type: ignore + world_addresses=rank_addresses, + ) + logger.debug( + f"Init transfer actor: {self._transfer_ref.address}, rank: {self._rank} done for vllm." # type: ignore + ) + + +def with_lock(method): + async def wrapper(self, *args, **kwargs): + async with self._lock: + return await method(self, *args, **kwargs) + + return wrapper + + +class CollectiveManager(xo.StatelessActor): + @classmethod + def default_uid(cls): + return f"xavier-collective-manager" + + def __init__(self, model_uid: str): + super().__init__() + self._model_uid = model_uid + self._tracker_ref: Optional[xo.ActorRefType["VLLMBlockTracker"]] = None + self._rank_to_ref: Dict[int, xo.ActorRefType["TransferActor"]] = {} + self._lock = asyncio.Lock() + + async def __post_create__(self): + self._tracker_ref = await xo.actor_ref( + address=self.address, + uid=f"{VLLMBlockTracker.default_uid()}-{self._model_uid}", + ) + + async def unregister_rank(self, rank: int): + self._rank_to_ref.pop(rank, None) + await self._tracker_ref.unregister_rank(rank) # type: ignore + logger.debug(f"Unregister rank: {rank}") + + async def register_rank(self, rank: int, address: str, update: bool = False): + from .transfer import TransferActor + + rank_ref = await xo.actor_ref( + address=address, uid=f"{TransferActor.default_uid()}-{rank}" + ) + self._rank_to_ref[rank] = rank_ref + logger.debug(f"Register rank: {rank}, address: {address}") + if update: + await self._update_world() + await self._tracker_ref.register_rank(rank) # type: ignore + + @with_lock + async def _update_world(self): + """ + Locking is used to prevent chaos when multiple replicas trigger recovery simultaneously. 
+ """ + from .....core.utils import gen_random_string + + prefix = gen_random_string(6) + tasks = [] + rank_to_ref = self._rank_to_ref.copy() + world_addresses = [ref.address for _, ref in sorted(rank_to_ref.items())] + for rank, ref in rank_to_ref.items(): + tasks.append(ref.connect_full_mesh(prefix, world_addresses)) + try: + logger.debug( + f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix}" + ) + await asyncio.gather(*tasks) + logger.debug( + f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix} done." + ) + except Exception as e: + """ + The exception here is most likely due to another replica triggering recovery during the recovery process, + causing `connect_full_mesh` to time out. + Simply log the exception and + let the subsequent update process handle the reconstruction of the collective communication world. + """ + logger.error( + f"Rebuild collective communication with world_addresses: {world_addresses} failed. " + f"Exception: {e}" + ) + # Print the complete error stack + traceback.print_exception(type(e), e, e.__traceback__) diff --git a/xinference/model/llm/vllm/xavier/executor.py b/xinference/model/llm/vllm/xavier/executor.py index 6b6e80a7b2..4d4ca53b06 100644 --- a/xinference/model/llm/vllm/xavier/executor.py +++ b/xinference/model/llm/vllm/xavier/executor.py @@ -64,14 +64,13 @@ async def init_transfer(self): ) async def _get_block_tracker_ref(self): - from .block_tracker import VLLMBlockTracker - if self._block_tracker_ref is None: block_tracker_address = self.vllm_config.xavier_config.get( "block_tracker_address" ) + block_tracker_uid = self.vllm_config.xavier_config.get("block_tracker_uid") self._block_tracker_ref = await xo.actor_ref( - address=block_tracker_address, uid=VLLMBlockTracker.default_uid() + address=block_tracker_address, uid=block_tracker_uid ) return self._block_tracker_ref @@ -86,8 +85,8 @@ async def _get_transfer_ref(self): ) return self._transfer_ref - def get_rank_address(self) -> str: - return self.vllm_config.xavier_config.get("rank_address") + def get_rank(self) -> int: + return self.vllm_config.xavier_config.get("rank") async def execute_model_async( self, @@ -100,7 +99,7 @@ async def execute_model_async( virtual_engine = execute_model_req.virtual_engine block_tracker_ref = await self._get_block_tracker_ref() scheduler = self.scheduler[virtual_engine] # type: ignore - rank_address = self.get_rank_address() + rank = self.get_rank() executed_blocks_details: Set[Tuple[int, int]] = set() for meta in execute_model_req.seq_group_metadata_list: block_tables = meta.block_tables @@ -117,16 +116,19 @@ async def execute_model_async( res = await super().execute_model_async(execute_model_req) - """ - Why not collect and register the information after execution? - Because after execution, the model's execution callback hook will release the block_id, - causing the block manager to lose access to the correct information. - """ - await block_tracker_ref.register_blocks( - virtual_engine, list(executed_blocks_details), rank_address - ) + if executed_blocks_details: + """ + Why not collect and register the information after execution? + Because after execution, the model's execution callback hook will release the block_id, + causing the block manager to lose access to the correct information. 
+ """ + await block_tracker_ref.register_blocks( + virtual_engine, list(executed_blocks_details), rank + ) - for _, _id in executed_blocks_details: - scheduler.block_manager.set_block_status_by_block_id("executed", _id, True) + for _, _id in executed_blocks_details: + scheduler.block_manager.set_block_status_by_block_id( + "executed", _id, True + ) return res diff --git a/xinference/model/llm/vllm/xavier/scheduler.py b/xinference/model/llm/vllm/xavier/scheduler.py index 4bb6a80e60..4c4fd5bbd4 100644 --- a/xinference/model/llm/vllm/xavier/scheduler.py +++ b/xinference/model/llm/vllm/xavier/scheduler.py @@ -72,12 +72,11 @@ def __init__( self._transfer_status: Dict[SequenceGroup, Set[int]] = {} async def _get_block_tracker_ref(self): - from .block_tracker import VLLMBlockTracker - if self._block_tracker_ref is None: block_tracker_address = self._xavier_config.get("block_tracker_address") + block_tracker_uid = self._xavier_config.get("block_tracker_uid") self._block_tracker_ref = await xo.actor_ref( - address=block_tracker_address, uid=VLLMBlockTracker.default_uid() + address=block_tracker_address, uid=block_tracker_uid ) return self._block_tracker_ref @@ -97,7 +96,12 @@ async def _get_transfer_details( virtual_engine: int, block_tables: Dict[int, List[int]], seq_group: SequenceGroup, - ) -> Tuple[Set[int], Dict[str, Set[Tuple[int, int, int]]]]: + ) -> Tuple[Set[int], Dict[int, Set[Tuple[int, int, int]]]]: + # If the `seq_group` has the `force_calculation` attribute set to `True`, + # it indicates that there were issues during the transmission process. + # In this case, force the computation and exclude it from the Xavier process. + if getattr(seq_group, "force_calculation", False): + return set(), dict() """ Retrieve information from other replicas to check if any blocks have already been computed, for the purpose of data transfer. @@ -132,48 +136,63 @@ async def _get_transfer_details( ) ): details.add(detail) - tracker_ref = await self._get_block_tracker_ref() - remote = await tracker_ref.query_blocks(virtual_engine, list(details)) - # Not all queried blocks have corresponding results in other replicas. - # Therefore, it is necessary to record which local block data was actually transferred. - local: Set[int] = set() - for _, remote_details in remote.items(): - for _, _, local_block_id in remote_details: - local.add(local_block_id) - if local: - logger.debug( - f"Data in local blocks: {local} will be transmitted from the remote." - ) - return local, remote + + if details: + tracker_ref = await self._get_block_tracker_ref() + remote = await tracker_ref.query_blocks(virtual_engine, list(details)) + # Not all queried blocks have corresponding results in other replicas. + # Therefore, it is necessary to record which local block data was actually transferred. + local: Set[int] = set() + for _, remote_details in remote.items(): + for _, _, local_block_id in remote_details: + local.add(local_block_id) + if local: + logger.debug( + f"Data in local blocks: {local} will be transmitted from the remote." 
+ ) + return local, remote + else: + return set(), dict() async def _do_transfer_inner( - self, virtual_engine: int, remote: Dict[str, Set[Tuple[int, int, int]]] + self, virtual_engine: int, remote: Dict[int, Set[Tuple[int, int, int]]] ): transfer_ref = await self._get_transfer_ref() - for addr, hash_and_block_id in remote.items(): + for from_rank, hash_and_block_id in remote.items(): src_to_dst: Dict[int, int] = {x[1]: x[2] for x in hash_and_block_id} - await transfer_ref.recv(virtual_engine, addr, src_to_dst) + await transfer_ref.recv(virtual_engine, from_rank, src_to_dst) async def _do_transfer( self, virtual_engine: int, local: Set[int], - remote: Dict[str, Set[Tuple[int, int, int]]], + remote: Dict[int, Set[Tuple[int, int, int]]], seq_group: SequenceGroup, - is_prefill: bool, ): - await self._do_transfer_inner(virtual_engine, remote) - # After the transfer is completed, update the corresponding metadata. - self._transfer_status[seq_group] = local - for _id in local: - self.block_manager.set_block_status_by_block_id("transferred", _id, True) - # After the transfer, place the `seq_group` back into the appropriate queue to - # wait for the next scheduling execution. - if is_prefill: + try: + await self._do_transfer_inner(virtual_engine, remote) + except Exception as e: + """ + The exception here is most likely due to the sender triggering recovery during the transmission process. + In this case, fallback to performing computation during the prefill stage. + """ + logger.error(f"Transfer failed: {e}") + # Force this `seq_group` to perform computation. + seq_group.force_calculation = True + self._transfer_status.pop(seq_group, None) self.waiting.appendleft(seq_group) + self._transferring.remove(seq_group) else: - self.running.appendleft(seq_group) - self._transferring.remove(seq_group) + # After the transfer is completed, update the corresponding metadata. + self._transfer_status[seq_group] = local + for _id in local: + self.block_manager.set_block_status_by_block_id( + "transferred", _id, True + ) + # After the transfer, place the `seq_group` back into the `waiting` queue to + # wait for the next scheduling execution. + self.waiting.appendleft(seq_group) + self._transferring.remove(seq_group) @no_type_check async def schedule( @@ -240,39 +259,36 @@ async def schedule( After completing the scheduling, the blocks have been allocated. Therefore, it is possible to check whether some blocks have already been computed on other replicas based on this information, and subsequently initiate the transfer. + According to the internal code comments in vllm, + whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage. + It is noted that data transmission is only applied during the prefill stage. + In the decode stage, it only applies to the last token of the block, which can negatively impact throughput. """ - local, remote = await self._get_transfer_details( - virtual_engine, block_tables, seq_group - ) - if remote: - running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - # According to the internal code comments in vllm, - # whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage. - is_prefill = token_chunk_size != 1 - for seq in running_seqs: - seq.status = ( - SequenceStatus.WAITING if is_prefill else SequenceStatus.RUNNING - ) - # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process. 
- # During the next scheduling, block allocation will no longer be required - # since it has already been completed. - seq.transferred = True - seq.data._stage = ( - SequenceStage.PREFILL if is_prefill else SequenceStage.DECODE - ) - self._transfer_status[seq_group] = set() - # Use `create_task` to avoid blocking subsequent scheduling. - asyncio.create_task( - self._do_transfer( - virtual_engine, local, remote, seq_group, is_prefill - ) + is_prefill: bool = token_chunk_size != 1 + if is_prefill: + local, remote = await self._get_transfer_details( + virtual_engine, block_tables, seq_group ) - # The `seq_group` that is currently being transferred enters a new queue. - self._transferring.append(seq_group) - has_transferring = True - continue - else: - scheduled_seq_groups.append(seq_group) + if remote: + running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + for seq in running_seqs: + seq.status = SequenceStatus.WAITING + # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process. + # During the next scheduling, block allocation will no longer be required + # since it has already been completed. + seq.transferred = True + seq.data._stage = SequenceStage.PREFILL + self._transfer_status[seq_group] = set() + # Use `create_task` to avoid blocking subsequent scheduling. + asyncio.create_task( + self._do_transfer(virtual_engine, local, remote, seq_group) + ) + # The `seq_group` that is currently being transferred enters a new queue. + self._transferring.append(seq_group) + has_transferring = True + continue + else: + scheduled_seq_groups.append(seq_group) if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( diff --git a/xinference/model/llm/vllm/xavier/test/test_xavier.py b/xinference/model/llm/vllm/xavier/test/test_xavier.py index 0caf3c2b7f..6a407006cb 100644 --- a/xinference/model/llm/vllm/xavier/test/test_xavier.py +++ b/xinference/model/llm/vllm/xavier/test/test_xavier.py @@ -21,11 +21,11 @@ class ExtendedBlockTracker(VLLMBlockTracker): - def get_hash_to_address_and_block_id(self): - return self._hash_to_address_and_block_id + def get_hash_to_rank_and_block_id(self): + return self._hash_to_rank_and_block_id - def get_address_to_hash_and_block_id(self): - return self._address_to_hash_and_block_id + def get_rank_to_hash_and_block_id(self): + return self._rank_to_hash_and_block_id @pytest.fixture @@ -53,53 +53,54 @@ async def test_block_tracker(actor_pool_context): ) virtual_engine = 0 + rank = 0 block_infos = [(123, 0), (456, 1), (789, 2)] # register blocks - await tracker_ref.register_blocks(virtual_engine, block_infos, addr) + await tracker_ref.register_blocks(virtual_engine, block_infos, rank) # query blocks res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5)]) assert len(res) == 1 - assert addr in res - assert len(res[addr]) == 2 - assert {x[0] for x in res[addr]} == {123, 789} - assert {x[1] for x in res[addr]} == {0, 2} - assert {x[2] for x in res[addr]} == {4, 5} + assert rank in res + assert len(res[rank]) == 2 + assert {x[0] for x in res[rank]} == {123, 789} + assert {x[1] for x in res[rank]} == {0, 2} + assert {x[2] for x in res[rank]} == {4, 5} # query with extra info res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5), (110, 6)]) assert len(res) == 1 - assert addr in res - assert len(res[addr]) == 2 - assert {x[0] for x in res[addr]} == {123, 789} - assert {x[1] for x in res[addr]} == {0, 2} - assert {x[2] for x in res[addr]} == {4, 5} + assert rank in res + assert 
len(res[rank]) == 2 + assert {x[0] for x in res[rank]} == {123, 789} + assert {x[1] for x in res[rank]} == {0, 2} + assert {x[2] for x in res[rank]} == {4, 5} # unregister block - await tracker_ref.unregister_block(virtual_engine, addr, 1) + await tracker_ref.unregister_block(virtual_engine, rank, 1) res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)]) assert len(res) == 1 - assert addr in res - assert len(res[addr]) == 1 - assert {x[0] for x in res[addr]} == {123} - assert {x[1] for x in res[addr]} == { + assert rank in res + assert len(res[rank]) == 1 + assert {x[0] for x in res[rank]} == {123} + assert {x[1] for x in res[rank]} == { 0, } - assert {x[2] for x in res[addr]} == { + assert {x[2] for x in res[rank]} == { 4, } # nothing happens - await tracker_ref.unregister_block(virtual_engine, addr, 3) + await tracker_ref.unregister_block(virtual_engine, rank, 3) res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)]) assert len(res) == 1 - assert addr in res - assert len(res[addr]) == 1 - assert {x[0] for x in res[addr]} == {123} - assert {x[1] for x in res[addr]} == { + assert rank in res + assert len(res[rank]) == 1 + assert {x[0] for x in res[rank]} == {123} + assert {x[1] for x in res[rank]} == { 0, } - assert {x[2] for x in res[addr]} == { + assert {x[2] for x in res[rank]} == { 4, } # query returns empty @@ -107,16 +108,40 @@ async def test_block_tracker(actor_pool_context): assert res == {} # check internal data - hash_to_address_and_block_id = await tracker_ref.get_hash_to_address_and_block_id() - assert virtual_engine in hash_to_address_and_block_id - assert hash_to_address_and_block_id[virtual_engine] == { + hash_to_rank_and_block_id = await tracker_ref.get_hash_to_rank_and_block_id() + assert virtual_engine in hash_to_rank_and_block_id + assert hash_to_rank_and_block_id[virtual_engine] == { 123: { - (addr, 0), + (rank, 0), }, 456: set(), - 789: {(addr, 2)}, + 789: {(rank, 2)}, } - address_to_hash_and_block_id = await tracker_ref.get_address_to_hash_and_block_id() - assert virtual_engine in address_to_hash_and_block_id - assert address_to_hash_and_block_id[virtual_engine] == {addr: {(123, 0), (789, 2)}} + rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id() + assert virtual_engine in rank_to_hash_and_block_id + assert rank_to_hash_and_block_id[virtual_engine] == {rank: {(123, 0), (789, 2)}} + + # register blocks + new_rank = 1 + block_infos = [(111, 7), (222, 8), (333, 9), (123, 10)] + await tracker_ref.register_blocks(virtual_engine, block_infos, new_rank) + + # test unregister rank + await tracker_ref.unregister_rank(0) + res = await tracker_ref.query_blocks(virtual_engine, [(789, 5)]) + assert len(res) == 0 + res = await tracker_ref.query_blocks(virtual_engine, [(123, 6)]) + assert len(res) == 1 + assert new_rank in res + + # check internal data + rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id() + assert rank in rank_to_hash_and_block_id[virtual_engine] + assert new_rank in rank_to_hash_and_block_id[virtual_engine] + + # test register rank + await tracker_ref.register_rank(0) + rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id() + assert rank not in rank_to_hash_and_block_id[virtual_engine] + assert new_rank in rank_to_hash_and_block_id[virtual_engine] diff --git a/xinference/model/llm/vllm/xavier/transfer.py b/xinference/model/llm/vllm/xavier/transfer.py index 0723e8aefb..92640ddb96 100644 --- a/xinference/model/llm/vllm/xavier/transfer.py +++ 
b/xinference/model/llm/vllm/xavier/transfer.py @@ -23,6 +23,8 @@ from vllm.utils import TORCH_DTYPE_TO_NUMPY_DTYPE, Device from vllm.worker.cache_engine import CacheEngine +from .collective import CollectiveRank + logger = logging.getLogger(__name__) @@ -89,7 +91,7 @@ def get_gloo_dtype(self, input_dtype: torch.dtype): return TypeMappingGloo[TORCH_DTYPE_TO_NUMPY_DTYPE[input_dtype]] -class TransferActor(xo.StatelessActor, BufferTransferMixin): +class TransferActor(xo.StatelessActor, BufferTransferMixin, CollectiveRank): @classmethod def default_uid(cls): return f"vllm-transfer-actor" @@ -104,38 +106,21 @@ def __init__( world_addresses: List[str], ): super().__init__() - self._rank = rank - self._world_size = world_size - self._store_address = store_address - self._rank_address = rank_address - self._store_port = store_port - self._world_addresses = world_addresses - self._context = None + CollectiveRank.__init__( + self, + rank, + world_size, + rank_address, + store_address, + store_port, + world_addresses, + ) self._cache_engine: Optional[List[CacheEngine]] = None self._scheduler: Optional[List[Scheduler]] = None self._swap_stream = torch.cuda.Stream() async def __post_create__(self): - from xoscar.collective import xoscar_pygloo as xp - - context = xp.rendezvous.Context(self._rank, self._world_size) - - attr = xp.transport.tcp.attr(self._rank_address.split(":")[0]) - dev = xp.transport.tcp.CreateDevice(attr) - - opt = xp.rendezvous.TCPStoreOptions() - opt.port = self._store_port - opt.numWorkers = self._world_size - opt.isServer = self._rank == 0 - - store = xp.rendezvous.TCPStore(self._store_address, opt) - store = xp.rendezvous.PrefixStore(str(self._world_size), store) - - context.connectFullMesh(store, dev) - self._context = context - logger.debug( - f"Rank {self._rank} arrives successfully, world addresses: {self._world_addresses}" - ) + self.init_rank() def setup( self, @@ -153,6 +138,9 @@ def setup( num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory ) + async def __pre_destroy__(self): + self._context.closeConnections() + def _get_cache_engine(self, virtual_engine: int) -> CacheEngine: return self._cache_engine[virtual_engine] # type: ignore @@ -281,18 +269,51 @@ async def do_recv( self.free_buffer_index(cpu_buf_index) async def recv( - self, virtual_engine: int, from_address: str, src_to_dst: Dict[int, int] + self, virtual_engine: int, from_rank: int, src_to_dst: Dict[int, int] ): """ This is the external entry point for the call. The transfer logic is as follows: the receiver requests the sender to send the data directly to itself in a point-to-point manner. """ - rank = self._world_addresses.index(from_address) + from_address = self._world_addresses[from_rank] sender_ref = await xo.actor_ref( - address=from_address, uid=f"{TransferActor.default_uid()}-{rank}" + address=from_address, uid=f"{TransferActor.default_uid()}-{from_rank}" ) await asyncio.gather( sender_ref.do_send(virtual_engine, self._rank, src_to_dst), - self.do_recv(virtual_engine, rank, src_to_dst), + self.do_recv(virtual_engine, from_rank, src_to_dst), ) + + +class Rank0TransferActor(xo.StatelessActor, CollectiveRank): + """ + The Rank 0 transfer actor is only used for constructing the collective communication world, + so it only needs to inherit the `CollectiveWorld` class. 
+ """ + + @classmethod + def default_uid(cls): + return f"vllm-transfer-actor" + + def __init__( + self, + rank: int, + world_size: int, + rank_address: str, + store_address: str, + store_port: int, + world_addresses: List[str], + ): + CollectiveRank.__init__( + self, + rank, + world_size, + rank_address, + store_address, + store_port, + world_addresses, + ) + + async def __post_create__(self): + self.init_rank()