feat(BA-441): restart model service process #3282

Open · wants to merge 19 commits into base: main
1 change: 1 addition & 0 deletions changes/3282.feature.md
@@ -0,0 +1 @@
Add a model service process restart API (`POST /{service_id}/routings/{route_id}/restart`)
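
As a rough illustration only, a client call to the new endpoint might look like the sketch below. The URL prefix (`/services`), the authentication header, and the response handling are assumptions for illustration, not part of this PR.

```python
# Minimal sketch of invoking the new restart endpoint from a client.
# ASSUMPTIONS (not from this diff): the manager base URL, the "/services"
# route prefix, and the bearer-token auth header are placeholders only.
import asyncio

import aiohttp

MANAGER_URL = "https://manager.example.com"          # hypothetical
SERVICE_ID = "123e4567-e89b-12d3-a456-426614174000"  # hypothetical
ROUTE_ID = "00000000-0000-0000-0000-000000000000"    # hypothetical


async def restart_route() -> None:
    # Path suffix taken from the changelog entry above; the prefix is assumed.
    url = f"{MANAGER_URL}/services/{SERVICE_ID}/routings/{ROUTE_ID}/restart"
    async with aiohttp.ClientSession() as sess:
        async with sess.post(url, headers={"Authorization": "Bearer <token>"}) as resp:
            resp.raise_for_status()


asyncio.run(restart_route())
```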
228 changes: 74 additions & 154 deletions src/ai/backend/agent/agent.py
@@ -50,7 +50,6 @@
import aiotools
import attrs
import pkg_resources
import yaml
import zmq
import zmq.asyncio
from async_timeout import timeout
@@ -63,11 +62,10 @@
stop_after_delay,
wait_fixed,
)
from trafaret import DataError

from ai.backend.agent.model_service import ModelServiceManager
from ai.backend.common import msgpack, redis_helper
from ai.backend.common.bgtask import BackgroundTaskManager
from ai.backend.common.config import model_definition_iv
from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
from ai.backend.common.events import (
@@ -105,7 +103,6 @@
from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext
from ai.backend.common.service_ports import parse_service_ports
from ai.backend.common.types import (
MODEL_SERVICE_RUNTIME_PROFILES,
AbuseReportValue,
AcceleratorMetadata,
AgentId,
@@ -162,6 +159,7 @@
ContainerStatus,
KernelLifecycleStatus,
LifecycleEvent,
ModelServiceInfo,
MountInfo,
)
from .utils import generate_local_instance_id, get_arch_name
@@ -341,6 +339,8 @@ async def prepare_container(
environ: Mapping[str, str],
service_ports,
cluster_info: ClusterInfo,
*,
model_service_info: ModelServiceInfo | None = None,
) -> KernelObjectType:
raise NotImplementedError

@@ -2013,12 +2013,7 @@ async def create_kernel(
})

model_definition: Optional[Mapping[str, Any]] = None
# Read model config
model_folders = [
folder
for folder in vfolder_mounts
if folder.usage_mode == VFolderUsageMode.MODEL
]
model_service_info: ModelServiceInfo | None = None

if ctx.kernel_config["cluster_role"] in ("main", "master"):
for sport in parse_service_ports(
@@ -2063,14 +2058,45 @@
exposed_ports.append(port)
log.debug("exposed ports: {!r}", exposed_ports)
if kernel_config["session_type"] == SessionTypes.INFERENCE:
model_definition = await self.load_model_definition(
RuntimeVariant(
(kernel_config["internal_data"] or {}).get("runtime_variant", "custom")
),
model_folders,
environ,
service_ports,
kernel_config,
# Read model config
model_folders = [
folder
for folder in vfolder_mounts
if folder.usage_mode == VFolderUsageMode.MODEL
]

if len(model_folders) == 0:
raise AgentError("No model folder loaded for inference session")
model_folder = model_folders[0]
runtime_variant = RuntimeVariant(
(kernel_config["internal_data"] or {}).get("runtime_variant", "custom")
)

image_command = await self.extract_image_command(
kernel_config["image"]["canonical"]
)
model_definition_path: str | None = (
kernel_config.get("internal_data") or {}
).get("model_definition_path")

model_service_manager = ModelServiceManager(
runtime_variant, model_folder, model_definition_path=model_definition_path
)
model_definition = await model_service_manager.load_model_definition(
image_command=image_command,
)
environ.update(model_service_manager.create_environs(model_definition))
service_ports.extend(
model_service_manager.create_service_port_definitions(
model_definition,
service_ports,
)
)

model_service_info = ModelServiceInfo(
runtime_variant,
model_folder,
model_definition_path,
)

runtime_type = image_labels.get("ai.backend.runtime-type", "app")
@@ -2128,6 +2154,7 @@ async def create_kernel(
environ,
service_ports,
cluster_info,
model_service_info=model_service_info,
)
async with self.registry_lock:
self.kernel_registry[kernel_id] = kernel_obj
@@ -2240,6 +2267,7 @@ async def create_kernel(
asyncio.create_task(
self.start_and_monitor_model_service_health(kernel_obj, model)
)
kernel_obj.current_model_definition = model_definition

# Finally we are done.
await self.produce_event(
@@ -2282,148 +2310,40 @@ async def start_and_monitor_model_service_health(
)
)

async def load_model_definition(
async def restart_model_service(
self,
runtime_variant: RuntimeVariant,
model_folders: list[VFolderMount],
environ: MutableMapping[str, Any],
service_ports: list[ServicePort],
kernel_config: KernelCreationConfig,
) -> Any:
image_command = await self.extract_image_command(kernel_config["image"]["canonical"])
if runtime_variant != RuntimeVariant.CUSTOM and not image_command:
kernel_id: KernelId,
) -> None:
try:
kernel_obj = self.kernel_registry[kernel_id]
except KeyError:
raise AgentError(f"Kernel {kernel_id} not found")
if not kernel_obj.model_service_info:
raise AgentError(
"image should have its own command when runtime variant is set to values other than CUSTOM!"
"Model service info not loaded on kernel. Perhaps your kernel is not new enough to call this function."
)
assert len(model_folders) > 0
model_folder: VFolderMount = model_folders[0]

match runtime_variant:
case RuntimeVariant.VLLM:
_model = {
"name": "vllm-model",
"model_path": model_folder.kernel_path.as_posix(),
"service": {
"start_command": image_command,
"port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port,
"health_check": {
"path": MODEL_SERVICE_RUNTIME_PROFILES[
runtime_variant
].health_check_endpoint,
},
},
}
raw_definition = {"models": [_model]}

case RuntimeVariant.HUGGINGFACE_TGI:
_model = {
"name": "tgi-model",
"model_path": model_folder.kernel_path.as_posix(),
"service": {
"start_command": image_command,
"port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port,
"health_check": {
"path": MODEL_SERVICE_RUNTIME_PROFILES[
runtime_variant
].health_check_endpoint,
},
},
}
raw_definition = {"models": [_model]}

case RuntimeVariant.NIM:
_model = {
"name": "nim-model",
"model_path": model_folder.kernel_path.as_posix(),
"service": {
"start_command": image_command,
"port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port,
"health_check": {
"path": MODEL_SERVICE_RUNTIME_PROFILES[
runtime_variant
].health_check_endpoint,
},
},
}
raw_definition = {"models": [_model]}

case RuntimeVariant.CMD:
_model = {
"name": "image-model",
"model_path": model_folder.kernel_path.as_posix(),
"service": {
"start_command": image_command,
"port": 8000,
},
}
raw_definition = {"models": [_model]}

case RuntimeVariant.CUSTOM:
if _fname := (kernel_config.get("internal_data") or {}).get(
"model_definition_path"
):
model_definition_candidates = [_fname]
else:
model_definition_candidates = [
"model-definition.yaml",
"model-definition.yml",
]
model_service_manager = ModelServiceManager(
kernel_obj.model_service_info.runtime_variant,
kernel_obj.model_service_info.model_folder,
model_definition_path=kernel_obj.model_service_info.model_definition_path,
)

model_definition_path = None
for filename in model_definition_candidates:
if (Path(model_folder.host_path) / filename).is_file():
model_definition_path = Path(model_folder.host_path) / filename
break
image_command = await self.extract_image_command(kernel_obj.image.canonical)
model_definition = await model_service_manager.load_model_definition(
image_command=image_command,
)

if not model_definition_path:
raise AgentError(
f"Model definition file ({" or ".join(model_definition_candidates)}) does not exist under vFolder"
f" {model_folder.name} (ID {model_folder.vfid})",
)
try:
model_definition_yaml = await asyncio.get_running_loop().run_in_executor(
None, model_definition_path.read_text
)
except FileNotFoundError as e:
raise AgentError(
"Model definition file (model-definition.yml) does not exist under"
f" vFolder {model_folder.name} (ID {model_folder.vfid})",
) from e
try:
raw_definition = yaml.load(model_definition_yaml, Loader=yaml.FullLoader)
except yaml.error.YAMLError as e:
raise AgentError(f"Invalid YAML syntax: {e}") from e
try:
model_definition = model_definition_iv.check(raw_definition)
assert model_definition is not None
for model in model_definition["models"]:
if "BACKEND_MODEL_NAME" not in environ:
environ["BACKEND_MODEL_NAME"] = model["name"]
environ["BACKEND_MODEL_PATH"] = model["model_path"]
if service := model.get("service"):
if service["port"] in (2000, 2001):
raise AgentError("Port 2000 and 2001 are reserved for internal use")
overlapping_services = [
s for s in service_ports if service["port"] in s["container_ports"]
]
if len(overlapping_services) > 0:
raise AgentError(
f"Port {service["port"]} overlaps with built-in service"
f" {overlapping_services[0]["name"]}"
)
service_ports.append({
"name": f"{model["name"]}-{service["port"]}",
"protocol": ServicePortProtocols.PREOPEN,
"container_ports": (service["port"],),
"host_ports": (None,),
"is_inference": True,
})
return model_definition
except DataError as e:
raise AgentError(
"Failed to validate model definition from vFolder"
f" {model_folder.name} (ID {model_folder.vfid})",
) from e
for model_info in kernel_obj.current_model_definition["models"]:
log.debug("Shutting down model service {}", model_info["name"])
await kernel_obj.shutdown_model_service(model_info)

for model_info in model_definition["models"]:
log.debug("Starting model service {}", model_info["name"])
await self.start_and_monitor_model_service_health(
cast(KernelObjectType, kernel_obj), model_info
)
kernel_obj.current_model_definition = model_definition

def get_public_service_ports(self, service_ports: list[ServicePort]) -> list[ServicePort]:
return [port for port in service_ports if port["protocol"] != ServicePortProtocols.INTERNAL]
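Aside: `ModelServiceInfo`, imported from `..types` above, is defined in `src/ai/backend/agent/types.py`, which is not part of this diff. Judging purely from its call sites (`ModelServiceInfo(runtime_variant, model_folder, model_definition_path)` in `create_kernel` and the attribute reads in `restart_model_service`), it is presumably a small record along these lines; this is a hypothetical reconstruction, not the actual definition:

```python
# HYPOTHETICAL reconstruction of ModelServiceInfo, inferred only from the
# call sites visible in this diff; the real definition is in agent/types.py.
from typing import Optional

import attrs

from ai.backend.common.types import RuntimeVariant, VFolderMount


@attrs.define(slots=True)
class ModelServiceInfo:
    runtime_variant: RuntimeVariant       # e.g. RuntimeVariant.VLLM / CUSTOM
    model_folder: VFolderMount            # the mounted MODEL-usage vfolder
    model_definition_path: Optional[str]  # explicit definition file, if any
```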
13 changes: 12 additions & 1 deletion src/ai/backend/agent/docker/agent.py
@@ -88,7 +88,15 @@
from ..resources import AbstractComputePlugin, KernelResourceSpec, Mount, known_slot_types
from ..scratch import create_loop_filesystem, destroy_loop_filesystem
from ..server import get_extra_volumes
from ..types import AgentEventData, Container, ContainerStatus, LifecycleEvent, MountInfo, Port
from ..types import (
AgentEventData,
Container,
ContainerStatus,
LifecycleEvent,
ModelServiceInfo,
MountInfo,
Port,
)
from ..utils import (
closing_async,
container_pid_to_host_pid,
@@ -640,6 +648,8 @@ async def prepare_container(
environ: Mapping[str, str],
service_ports: List[ServicePort],
cluster_info: ClusterInfo,
*,
model_service_info: ModelServiceInfo | None = None,
) -> DockerKernel:
loop = current_loop()

@@ -764,6 +774,7 @@ def _populate_ssh_config():
resource_spec=resource_spec,
environ=environ,
data={},
model_service_info=model_service_info,
)
return kernel_obj

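Note that `model_service_info` is keyword-only (it follows a bare `*` in the signature), so existing positional call sites of `prepare_container` keep working while new call sites must name the argument. A toy, self-contained demonstration of that pattern, with all names as illustrative placeholders rather than code from the PR:

```python
# Toy demonstration of appending a keyword-only parameter, as done for
# prepare_container above; all names here are illustrative placeholders.
from typing import Optional


class Info:  # stand-in for ModelServiceInfo
    pass


def prepare(environ: dict, *, model_service_info: Optional[Info] = None) -> None:
    print(environ, model_service_info)


prepare({"A": "1"})                             # old positional call sites still work
prepare({"A": "1"}, model_service_info=Info())  # new call sites name it explicitly
# prepare({"A": "1"}, Info())                   # TypeError: must be passed by keyword
```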
8 changes: 7 additions & 1 deletion src/ai/backend/agent/docker/kernel.py
@@ -31,7 +31,7 @@

from ..kernel import AbstractCodeRunner, AbstractKernel
from ..resources import KernelResourceSpec
from ..types import AgentEventData
from ..types import AgentEventData, ModelServiceInfo
from ..utils import closing_async, get_arch_name

log = BraceStyleAdapter(logging.getLogger(__spec__.name))
@@ -58,6 +58,7 @@ def __init__(
service_ports: Any, # TODO: type-annotation
environ: Mapping[str, Any],
data: Dict[str, Any],
model_service_info: Optional[ModelServiceInfo],
) -> None:
super().__init__(
kernel_id,
@@ -71,6 +72,7 @@ def __init__(
service_ports=service_ports,
data=data,
environ=environ,
model_service_info=model_service_info,
)

self.network_driver = network_driver
@@ -153,6 +155,10 @@ async def shutdown_service(self, service: str):
assert self.runner is not None
await self.runner.feed_shutdown_service(service)

async def shutdown_model_service(self, model_service: Mapping[str, Any]):
assert self.runner is not None
await self.runner.feed_shutdown_model_service(model_service)

async def get_service_apps(self):
assert self.runner is not None
result = await self.runner.feed_service_apps()
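
The runner-side `feed_shutdown_model_service` referenced above is not part of this diff. By analogy with the other `feed_*` methods on Backend.AI code runners, it presumably forwards a shutdown request to the in-container runner; the sketch below is an assumption about its shape, with the socket attribute and wire format invented for illustration:

```python
# HYPOTHETICAL sketch of the runner-side counterpart referenced above; the
# real method lives in the code-runner class outside this diff. The socket
# attribute and message framing are assumptions modeled on other feed_* calls.
from typing import Any, Mapping

from ai.backend.common import msgpack


class CodeRunnerSketch:
    input_sock: Any  # assumed: a zmq.asyncio socket to the in-container runner

    async def feed_shutdown_model_service(
        self,
        model_service: Mapping[str, Any],
    ) -> None:
        # Ship the model-service descriptor so the in-container runner can
        # stop the matching service process.
        await self.input_sock.send_multipart([
            b"shutdown-model-service",
            msgpack.packb(dict(model_service)),
        ])
```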