ENH: Launch model by version #896

Merged · 14 commits · Jan 22, 2024
54 changes: 54 additions & 0 deletions xinference/api/restful_api.py
@@ -263,6 +263,14 @@ def serve(self, logging_conf: Optional[dict] = None):
            if self.is_authenticated()
            else None,
        )
        self._router.add_api_route(
            "/v1/models/{model_type}/{model_name}/versions",
            self.get_model_versions,
            methods=["GET"],
            dependencies=[Security(verify_token, scopes=["models:list"])]
            if self.is_authenticated()
            else None,
        )
        self._router.add_api_route(
            "/v1/models",
            self.list_models,
@@ -280,6 +288,14 @@ def serve(self, logging_conf: Optional[dict] = None):
            if self.is_authenticated()
            else None,
        )
        self._router.add_api_route(
            "/v1/models/instance",
            self.launch_model_by_version,
            methods=["POST"],
            dependencies=[Security(verify_token, scopes=["models:start"])]
            if self.is_authenticated()
            else None,
        )
        self._router.add_api_route(
            "/v1/models",
            self.launch_model,
@@ -640,6 +656,44 @@ async def get_instance_info(
            raise HTTPException(status_code=500, detail=str(e))
        return JSONResponse(content=infos)

    async def launch_model_by_version(
        self, request: Request, wait_ready: bool = Query(True)
    ) -> JSONResponse:
        payload = await request.json()
        model_uid = payload.get("model_uid")
        model_type = payload.get("model_type")
        model_version = payload.get("model_version")
        replica = payload.get("replica", 1)
        n_gpu = payload.get("n_gpu", "auto")

        try:
            model_uid = await (
                await self._get_supervisor_ref()
            ).launch_model_by_version(
                model_uid=model_uid,
                model_type=model_type,
                model_version=model_version,
                replica=replica,
                n_gpu=n_gpu,
                wait_ready=wait_ready,
            )
        except Exception as e:
            logger.error(str(e), exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))
        return JSONResponse(content={"model_uid": model_uid})

    async def get_model_versions(
        self, model_type: str, model_name: str
    ) -> JSONResponse:
        try:
            content = await (await self._get_supervisor_ref()).get_model_versions(
                model_type, model_name
            )
            return JSONResponse(content=content)
        except Exception as e:
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

    async def build_gradio_interface(
        self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
    ) -> JSONResponse:
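A minimal client sketch for exercising the two new endpoints, assuming a local, unauthenticated deployment; the endpoint address (Xinference's usual default port) and the model name are illustrative, and the requests library is not part of this diff:

import requests

base_url = "http://127.0.0.1:9997"  # illustrative; adjust to your deployment

# GET /v1/models/{model_type}/{model_name}/versions — list known versions
resp = requests.get(f"{base_url}/v1/models/LLM/llama-2-chat/versions")
resp.raise_for_status()
versions = resp.json()  # list of version-info dicts from the cache tracker

# POST /v1/models/instance — launch from one of the listed versions;
# payload fields mirror launch_model_by_version above, and the optional
# wait_ready query parameter defaults to True
payload = {
    "model_uid": None,  # let the supervisor generate a uid
    "model_type": "LLM",
    "model_version": versions[0]["model_version"],
    "replica": 1,
    "n_gpu": "auto",
}
resp = requests.post(f"{base_url}/v1/models/instance", json=payload)
resp.raise_for_status()
print(resp.json()["model_uid"])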
99 changes: 99 additions & 0 deletions xinference/core/cache_tracker.py
@@ -0,0 +1,99 @@
# Copyright 2022-2024 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Dict, List, Optional

import xoscar as xo

logger = getLogger(__name__)


class CacheTrackerActor(xo.Actor):
    def __init__(self):
        super().__init__()
        self._model_name_to_version_info: Dict[str, List[Dict]] = {}

    @classmethod
    def uid(cls) -> str:
        return "cache_tracker"

    @staticmethod
    def _map_address_to_file_location(
        model_version: Dict[str, List[Dict]], address: str
    ):
        for model_versions in model_version.values():
            for info_dict in model_versions:
                info_dict["model_file_location"] = (
                    {address: info_dict["model_file_location"]}
                    if info_dict["cache_status"]
                    else None
                )

    @staticmethod
    def _update_file_location(data: Dict, origin_version_info: Dict):
        if origin_version_info["model_file_location"] is None:
            origin_version_info["model_file_location"] = data
        else:
            assert isinstance(origin_version_info["model_file_location"], dict)
            origin_version_info["model_file_location"].update(data)

    def record_model_version(self, version_info: Dict[str, List[Dict]], address: str):
        self._map_address_to_file_location(version_info, address)
        for model_name, model_versions in version_info.items():
            if model_name not in self._model_name_to_version_info:
                self._model_name_to_version_info[model_name] = model_versions
            else:
                assert len(model_versions) == len(
                    self._model_name_to_version_info[model_name]
                ), "Model version info inconsistency between supervisor and worker"
                for version, origin_version in zip(
                    model_versions, self._model_name_to_version_info[model_name]
                ):
                    if (
                        version["cache_status"]
                        and version["model_file_location"] is not None
                    ):
                        origin_version["cache_status"] = True
                        self._update_file_location(
                            version["model_file_location"], origin_version
                        )

    def update_cache_status(
        self,
        address: str,
        model_name: str,
        model_version: Optional[str],
        model_path: str,
    ):
        if model_name not in self._model_name_to_version_info:
            logger.warning(f"No version info recorded for {model_name} yet.")
        else:
            for version_info in self._model_name_to_version_info[model_name]:
                if model_version is None:  # image models carry no version string
                    self._update_file_location({address: model_path}, version_info)
                    version_info["cache_status"] = True
                elif version_info["model_version"] == model_version:
                    self._update_file_location({address: model_path}, version_info)
                    version_info["cache_status"] = True

    def unregister_model_version(self, model_name: str):
        self._model_name_to_version_info.pop(model_name, None)

    def get_model_versions(self, model_name: str) -> List[Dict]:
        if model_name not in self._model_name_to_version_info:
            logger.warning(f"No version info recorded for model: {model_name}")
            return []
        return self._model_name_to_version_info[model_name]
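For orientation, a sketch of the version-info shape this actor manages; the field names are taken from the code above, while all values (model name, path, address, and the version string, which is produced by the model description generators outside this file) are illustrative:

# One entry of _model_name_to_version_info before a worker address is mapped in:
version_info = {
    "llama-2-chat": [  # illustrative model name
        {
            "model_version": "<version string>",  # hypothetical placeholder
            "cache_status": True,
            "model_file_location": "/path/on/worker/to/llama-2-chat",
        },
    ]
}

# record_model_version(version_info, "worker-1:30001") first rewrites each
# cached entry's file location to be keyed by worker address, e.g.
#   {"model_file_location": {"worker-1:30001": "/path/on/worker/to/llama-2-chat"}}
# so _update_file_location can later merge locations reported by other workers
# for the same version.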
1 change: 1 addition & 0 deletions xinference/core/status_guard.py
@@ -33,6 +33,7 @@ class LaunchStatus(Enum):
class InstanceInfo(BaseModel):
    model_name: str
    model_uid: str
    model_version: Optional[str]
    model_ability: List[str]
    replica: int
    status: str
95 changes: 89 additions & 6 deletions xinference/core/supervisor.py
@@ -32,6 +32,7 @@
    iter_replica_model_uid,
    log_async,
    log_sync,
    parse_model_version,
    parse_replica_model_uid,
)

@@ -89,37 +90,74 @@ async def __post_create__(self):
        # comment this line to avoid worker lost
        # self._check_dead_nodes_task = asyncio.create_task(self._check_dead_nodes())
        logger.info(f"Xinference supervisor {self.address} started")
        from .cache_tracker import CacheTrackerActor
        from .status_guard import StatusGuardActor

        self._status_guard_ref: xo.ActorRefType[
            "StatusGuardActor"
        ] = await xo.create_actor(
            StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
        )
        self._cache_tracker_ref: xo.ActorRefType[
            "CacheTrackerActor"
        ] = await xo.create_actor(
            CacheTrackerActor, address=self.address, uid=CacheTrackerActor.uid()
        )

        from ..model.embedding import (
            CustomEmbeddingModelSpec,
            generate_embedding_description,
            get_embedding_model_descriptions,
            register_embedding,
            unregister_embedding,
        )
        from ..model.image import get_image_model_descriptions
        from ..model.llm import (
            CustomLLMFamilyV1,
            generate_llm_description,
            get_llm_model_descriptions,
            register_llm,
            unregister_llm,
        )
        from ..model.rerank import (
            CustomRerankModelSpec,
            generate_rerank_description,
            get_rerank_model_descriptions,
            register_rerank,
            unregister_rerank,
        )

        self._custom_register_type_to_cls: Dict[str, Tuple] = {
"LLM": (CustomLLMFamilyV1, register_llm, unregister_llm),
"LLM": (
CustomLLMFamilyV1,
register_llm,
unregister_llm,
generate_llm_description,
),
"embedding": (
CustomEmbeddingModelSpec,
register_embedding,
unregister_embedding,
generate_embedding_description,
),
"rerank": (
CustomRerankModelSpec,
register_rerank,
unregister_rerank,
generate_rerank_description,
),
"rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
        }

        # record model version
        model_version_infos: Dict[str, List[Dict]] = {}
        model_version_infos.update(get_llm_model_descriptions())
        model_version_infos.update(get_embedding_model_descriptions())
        model_version_infos.update(get_rerank_model_descriptions())
        model_version_infos.update(get_image_model_descriptions())
        await self._cache_tracker_ref.record_model_version(
            model_version_infos, self.address
        )

    @staticmethod
    async def get_builtin_prompts() -> Dict[str, Any]:
        from ..model.llm.llm_family import BUILTIN_LLM_PROMPT_STYLE
@@ -420,6 +458,7 @@ async def register_model(self, model_type: str, model: str, persist: bool):
                model_spec_cls,
                register_fn,
                unregister_fn,
                generate_fn,
            ) = self._custom_register_type_to_cls[model_type]

            if not self.is_local_deployment():
@@ -430,6 +469,9 @@ async def register_model(self, model_type: str, model: str, persist: bool):
            model_spec = model_spec_cls.parse_raw(model)
            try:
                register_fn(model_spec, persist)
                await self._cache_tracker_ref.record_model_version(
                    generate_fn(model_spec), self.address
                )
            except Exception as e:
                unregister_fn(model_spec.model_name, raise_error=False)
                raise e
@@ -439,8 +481,9 @@ async def register_model(self, model_type: str, model: str, persist: bool):
    @log_async(logger=logger)
    async def unregister_model(self, model_type: str, model_name: str):
        if model_type in self._custom_register_type_to_cls:
            _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
            unregister_fn(model_name)
            await self._cache_tracker_ref.unregister_model_version(model_name)

        if not self.is_local_deployment():
            workers = list(self._worker_address_to_worker.values())
@@ -457,6 +500,44 @@ def _gen_model_uid(self, model_name: str) -> str:
        )
        return f"{model_name}-{gen_random_string(8)}"

    @log_async(logger=logger)
    async def get_model_versions(self, model_type: str, model_name: str) -> List[Dict]:
        logger.debug(
            f"Get model versions of model_name: {model_name}, model_type: {model_type}"
        )
        return await self._cache_tracker_ref.get_model_versions(model_name)

    @log_async(logger=logger)
    async def launch_model_by_version(
        self,
        model_uid: Optional[str],
        model_type: str,
        model_version: str,
        replica: int = 1,
        n_gpu: Optional[Union[int, str]] = "auto",
        wait_ready: bool = True,
    ):
        parse_results = parse_model_version(model_version, model_type)

        if model_type == "image" and len(parse_results) == 2:
            kwargs = {"controlnet": parse_results[1]}
        else:
            kwargs = {}

        return await self.launch_builtin_model(
            model_uid=model_uid,
            model_name=parse_results[0],
            model_size_in_billions=parse_results[1] if model_type == "LLM" else None,
            model_format=parse_results[2] if model_type == "LLM" else None,
            quantization=parse_results[3] if model_type == "LLM" else None,
            model_type=model_type,
            replica=replica,
            n_gpu=n_gpu,
            wait_ready=wait_ready,
            model_version=model_version,
            **kwargs,
        )

    async def launch_speculative_llm(
        self,
        model_uid: Optional[str],
@@ -529,6 +610,7 @@ async def launch_builtin_model(
        n_gpu: Optional[Union[int, str]] = "auto",
        request_limits: Optional[int] = None,
        wait_ready: bool = True,
        model_version: Optional[str] = None,
        **kwargs,
    ) -> str:
        if model_uid is None:
@@ -601,6 +683,7 @@ async def _launch_model():
        instance_info = InstanceInfo(
            model_name=model_name,
            model_uid=model_uid,
            model_version=model_version,
            model_ability=[],
            replica=replica,
            status=LaunchStatus.CREATING.name,
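The string format accepted by parse_model_version is defined outside this diff; only the tuple layout can be inferred from the call sites in launch_model_by_version above. A hypothetical stand-in illustrating that layout (the "--" delimiter and field parsing are assumptions, not the real implementation):

from typing import Tuple

def parse_model_version_sketch(model_version: str, model_type: str) -> Tuple:
    # Hypothetical parser; delimiter and field order are assumptions
    # inferred from the supervisor call sites above.
    parts = model_version.split("--")
    if model_type == "LLM":
        # (model_name, model_size_in_billions, model_format, quantization)
        name, size, fmt, quantization = parts
        return name, int(size), fmt, quantization
    if model_type == "image":
        # (model_name,) or (model_name, controlnet)
        return tuple(parts)
    return (parts[0],)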