Skip to content

Commit

Permalink
ENH: Launch model by version (#896)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengjieLi28 authored Jan 22, 2024
1 parent 235a8d9 commit 99988c6
Show file tree
Hide file tree
Showing 23 changed files with 776 additions and 48 deletions.
54 changes: 54 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ def serve(self, logging_conf: Optional[dict] = None):
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models/{model_type}/{model_name}/versions",
self.get_model_versions,
methods=["GET"],
dependencies=[Security(verify_token, scopes=["models:list"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models",
self.list_models,
Expand All @@ -280,6 +288,14 @@ def serve(self, logging_conf: Optional[dict] = None):
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models/instance",
self.launch_model_by_version,
methods=["POST"],
dependencies=[Security(verify_token, scopes=["models:start"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models",
self.launch_model,
Expand Down Expand Up @@ -640,6 +656,44 @@ async def get_instance_info(
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(content=infos)

async def launch_model_by_version(
    self, request: Request, wait_ready: bool = Query(True)
) -> JSONResponse:
    """Launch a model instance identified by a version string.

    The JSON body carries ``model_uid``, ``model_type`` and
    ``model_version`` plus optional ``replica`` / ``n_gpu`` settings; the
    supervisor resolves the version and performs the actual launch.
    Returns the resulting model uid, or HTTP 500 on any launch failure.
    """
    body = await request.json()
    launch_args = {
        "model_uid": body.get("model_uid"),
        "model_type": body.get("model_type"),
        "model_version": body.get("model_version"),
        "replica": body.get("replica", 1),
        "n_gpu": body.get("n_gpu", "auto"),
    }

    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_uid = await supervisor_ref.launch_model_by_version(
            wait_ready=wait_ready, **launch_args
        )
    except Exception as e:
        logger.error(str(e), exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return JSONResponse(content={"model_uid": model_uid})

async def get_model_versions(
    self, model_type: str, model_name: str
) -> JSONResponse:
    """Return every recorded version entry for *model_name* of *model_type*.

    Delegates to the supervisor; any error is logged and surfaced as
    HTTP 500.
    """
    try:
        supervisor_ref = await self._get_supervisor_ref()
        versions = await supervisor_ref.get_model_versions(
            model_type, model_name
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return JSONResponse(content=versions)

async def build_gradio_interface(
self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
) -> JSONResponse:
Expand Down
99 changes: 99 additions & 0 deletions xinference/core/cache_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022-2024 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Dict, List, Optional

import xoscar as xo

logger = getLogger(__name__)


class CacheTrackerActor(xo.Actor):
    """Actor that tracks, per model name, the version info reported by
    supervisor and workers — including on which node(s) each version's
    files are cached.
    """

    def __init__(self):
        super().__init__()
        # model_name -> list of version-info dicts. Each dict carries at
        # least "model_version", "cache_status" and "model_file_location".
        self._model_name_to_version_info: Dict[str, List[Dict]] = {}

    @classmethod
    def uid(cls) -> str:
        return "cache_tracker"

    @staticmethod
    def _map_address_to_file_location(
        model_version: Dict[str, List[Dict]], address: str
    ):
        """Rewrite each plain file path into an ``{address: path}`` mapping.

        Entries that are not cached get ``None`` so later updates can tell
        "never cached" apart from "cached somewhere". Mutates *model_version*
        in place.
        """
        for model_name, model_versions in model_version.items():
            for info_dict in model_versions:
                info_dict["model_file_location"] = (
                    {address: info_dict["model_file_location"]}
                    if info_dict["cache_status"]
                    else None
                )

    @staticmethod
    def _update_file_location(data: Dict, origin_version_info: Dict):
        """Merge an ``{address: path}`` mapping into an existing entry."""
        if origin_version_info["model_file_location"] is None:
            origin_version_info["model_file_location"] = data
        else:
            assert isinstance(origin_version_info["model_file_location"], dict)
            origin_version_info["model_file_location"].update(data)

    def record_model_version(self, version_info: Dict[str, List[Dict]], address: str):
        """Record version info reported from node *address*.

        First-time names are stored as-is; for known names the per-version
        cache status and file locations reported by *address* are merged
        into the existing records. The version lists must line up one-to-one
        with what was recorded before.
        """
        self._map_address_to_file_location(version_info, address)
        for model_name, model_versions in version_info.items():
            if model_name not in self._model_name_to_version_info:
                self._model_name_to_version_info[model_name] = model_versions
            else:
                assert len(model_versions) == len(
                    self._model_name_to_version_info[model_name]
                ), "Model version info inconsistency between supervisor and worker"
                for version, origin_version in zip(
                    model_versions, self._model_name_to_version_info[model_name]
                ):
                    if (
                        version["cache_status"]
                        and version["model_file_location"] is not None
                    ):
                        origin_version["cache_status"] = True
                        self._update_file_location(
                            version["model_file_location"], origin_version
                        )

    def update_cache_status(
        self,
        address: str,
        model_name: str,
        model_version: Optional[str],
        model_path: str,
    ):
        """Mark a model as cached at *model_path* on node *address*.

        When *model_version* is ``None`` (image models) every recorded
        version of the model is updated; otherwise only the matching
        version entry is touched.
        """
        if model_name not in self._model_name_to_version_info:
            logger.warning(f"No version info recorded for {model_name} yet.")
        else:
            for version_info in self._model_name_to_version_info[model_name]:
                if model_version is None:  # image model
                    self._update_file_location({address: model_path}, version_info)
                    version_info["cache_status"] = True
                else:
                    if version_info["model_version"] == model_version:
                        self._update_file_location({address: model_path}, version_info)
                        version_info["cache_status"] = True

    def unregister_model_version(self, model_name: str):
        """Drop all recorded version info for *model_name* (no-op if absent)."""
        self._model_name_to_version_info.pop(model_name, None)

    def get_model_versions(self, model_name: str) -> List[Dict]:
        """Return the recorded version entries for *model_name*, or ``[]``."""
        if model_name not in self._model_name_to_version_info:
            logger.warning(f"No version info recorded for model_name: {model_name}")
            return []
        else:
            return self._model_name_to_version_info[model_name]
1 change: 1 addition & 0 deletions xinference/core/status_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LaunchStatus(Enum):
class InstanceInfo(BaseModel):
model_name: str
model_uid: str
model_version: Optional[str]
model_ability: List[str]
replica: int
status: str
Expand Down
95 changes: 89 additions & 6 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
iter_replica_model_uid,
log_async,
log_sync,
parse_model_version,
parse_replica_model_uid,
)

Expand Down Expand Up @@ -89,37 +90,74 @@ async def __post_create__(self):
# comment this line to avoid worker lost
# self._check_dead_nodes_task = asyncio.create_task(self._check_dead_nodes())
logger.info(f"Xinference supervisor {self.address} started")
from .cache_tracker import CacheTrackerActor
from .status_guard import StatusGuardActor

self._status_guard_ref: xo.ActorRefType[
"StatusGuardActor"
] = await xo.create_actor(
StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
)
self._cache_tracker_ref: xo.ActorRefType[
"CacheTrackerActor"
] = await xo.create_actor(
CacheTrackerActor, address=self.address, uid=CacheTrackerActor.uid()
)

from ..model.embedding import (
CustomEmbeddingModelSpec,
generate_embedding_description,
get_embedding_model_descriptions,
register_embedding,
unregister_embedding,
)
from ..model.llm import register_llm, unregister_llm
from ..model.llm.llm_family import CustomLLMFamilyV1
from ..model.rerank.custom import (
from ..model.image import get_image_model_descriptions
from ..model.llm import (
CustomLLMFamilyV1,
generate_llm_description,
get_llm_model_descriptions,
register_llm,
unregister_llm,
)
from ..model.rerank import (
CustomRerankModelSpec,
generate_rerank_description,
get_rerank_model_descriptions,
register_rerank,
unregister_rerank,
)

self._custom_register_type_to_cls: Dict[str, Tuple] = {
"LLM": (CustomLLMFamilyV1, register_llm, unregister_llm),
"LLM": (
CustomLLMFamilyV1,
register_llm,
unregister_llm,
generate_llm_description,
),
"embedding": (
CustomEmbeddingModelSpec,
register_embedding,
unregister_embedding,
generate_embedding_description,
),
"rerank": (
CustomRerankModelSpec,
register_rerank,
unregister_rerank,
generate_rerank_description,
),
"rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
}

# record model version
model_version_infos: Dict[str, List[Dict]] = {}
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)

@staticmethod
async def get_builtin_prompts() -> Dict[str, Any]:
from ..model.llm.llm_family import BUILTIN_LLM_PROMPT_STYLE
Expand Down Expand Up @@ -420,6 +458,7 @@ async def register_model(self, model_type: str, model: str, persist: bool):
model_spec_cls,
register_fn,
unregister_fn,
generate_fn,
) = self._custom_register_type_to_cls[model_type]

if not self.is_local_deployment():
Expand All @@ -430,6 +469,9 @@ async def register_model(self, model_type: str, model: str, persist: bool):
model_spec = model_spec_cls.parse_raw(model)
try:
register_fn(model_spec, persist)
await self._cache_tracker_ref.record_model_version(
generate_fn(model_spec), self.address
)
except Exception as e:
unregister_fn(model_spec.model_name, raise_error=False)
raise e
Expand All @@ -439,8 +481,9 @@ async def register_model(self, model_type: str, model: str, persist: bool):
@log_async(logger=logger)
async def unregister_model(self, model_type: str, model_name: str):
if model_type in self._custom_register_type_to_cls:
_, _, unregister_fn = self._custom_register_type_to_cls[model_type]
_, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
unregister_fn(model_name)
await self._cache_tracker_ref.unregister_model_version(model_name)

if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
Expand All @@ -457,6 +500,44 @@ def _gen_model_uid(self, model_name: str) -> str:
)
return f"{model_name}-{gen_random_string(8)}"

@log_async(logger=logger)
async def get_model_versions(self, model_type: str, model_name: str) -> List[Dict]:
    """Return the cache tracker's recorded version entries for a model."""
    logger.debug(
        f"Get model versions of model_name: {model_name}, model_type: {model_type}"
    )
    tracker_ref = self._cache_tracker_ref
    return await tracker_ref.get_model_versions(model_name)

@log_async(logger=logger)
async def launch_model_by_version(
    self,
    model_uid: Optional[str],
    model_type: str,
    model_version: str,
    replica: int = 1,
    n_gpu: Optional[Union[int, str]] = "auto",
    wait_ready: bool = True,
):
    """Resolve *model_version* into concrete launch parameters and delegate
    to ``launch_builtin_model``.
    """
    parsed = parse_model_version(model_version, model_type)
    is_llm = model_type == "LLM"

    # An image version string may carry a controlnet component as its
    # second field; forward it as a launch kwarg when present.
    extra_kwargs = (
        {"controlnet": parsed[1]}
        if model_type == "image" and len(parsed) == 2
        else {}
    )

    return await self.launch_builtin_model(
        model_uid=model_uid,
        model_name=parsed[0],
        model_size_in_billions=parsed[1] if is_llm else None,
        model_format=parsed[2] if is_llm else None,
        quantization=parsed[3] if is_llm else None,
        model_type=model_type,
        replica=replica,
        n_gpu=n_gpu,
        wait_ready=wait_ready,
        model_version=model_version,
        **extra_kwargs,
    )

async def launch_speculative_llm(
self,
model_uid: Optional[str],
Expand Down Expand Up @@ -529,6 +610,7 @@ async def launch_builtin_model(
n_gpu: Optional[Union[int, str]] = "auto",
request_limits: Optional[int] = None,
wait_ready: bool = True,
model_version: Optional[str] = None,
**kwargs,
) -> str:
if model_uid is None:
Expand Down Expand Up @@ -601,6 +683,7 @@ async def _launch_model():
instance_info = InstanceInfo(
model_name=model_name,
model_uid=model_uid,
model_version=model_version,
model_ability=[],
replica=replica,
status=LaunchStatus.CREATING.name,
Expand Down
Loading

0 comments on commit 99988c6

Please sign in to comment.