Skip to content

Commit

Permalink
FEAT: Auto recover limit (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Jan 18, 2024
1 parent 96277bb commit 45f1d49
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
29 changes: 28 additions & 1 deletion xinference/client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,13 @@ def test_client_custom_embedding_model(setup):
assert custom_model_reg is None


@pytest.fixture
def set_auto_recover_limit():
    """Set ``XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT`` to ``"1"`` for one test.

    Whatever value the variable had before the test (including "not set")
    is put back on teardown, so a limit exported by the surrounding
    environment is not lost once the test has run.
    """
    env_key = "XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT"
    previous = os.environ.get(env_key)
    os.environ[env_key] = "1"
    yield
    if previous is None:
        # pop() instead of del: tolerate the key having been removed already.
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = previous


@pytest.fixture
def setup_cluster():
import xoscar as xo
Expand Down Expand Up @@ -604,7 +611,7 @@ def setup_cluster():
local_cluster.terminate()


def test_auto_recover(setup_cluster):
def test_auto_recover(set_auto_recover_limit, setup_cluster):
endpoint, _ = setup_cluster
current_proc = psutil.Process()
chilren_proc = set(current_proc.children(recursive=True))
Expand Down Expand Up @@ -636,3 +643,23 @@ def test_auto_recover(setup_cluster):
time.sleep(1)
else:
assert False

new_children_proc = set(current_proc.children(recursive=True))
model_proc = next(iter(new_children_proc - chilren_proc))
assert len(client.list_models()) == 1

model_proc.kill()

expect_failed = False
for _ in range(5):
try:
completion = model.generate(
"Once upon a time, there was a very old computer", {"max_tokens": 64}
)
assert "text" in completion["choices"][0]
break
except Exception:
time.sleep(1)
else:
expect_failed = True
assert expect_failed
65 changes: 53 additions & 12 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@


DEFAULT_NODE_HEARTBEAT_INTERVAL = 5

# Raw text of the XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT environment
# variable, or None when the variable is not set.
_MODEL_ACTOR_AUTO_RECOVER_LIMIT = os.getenv("XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT")

# Integer form of the variable above; None when it is not set. A value that
# int() cannot parse raises ValueError when this module is imported.
MODEL_ACTOR_AUTO_RECOVER_LIMIT: Optional[int] = (
    None
    if _MODEL_ACTOR_AUTO_RECOVER_LIMIT is None
    else int(_MODEL_ACTOR_AUTO_RECOVER_LIMIT)
)


class WorkerActor(xo.StatelessActor):
Expand All @@ -58,20 +64,47 @@ def __init__(
self._gpu_to_model_uid: Dict[int, str] = {}
self._gpu_to_embedding_model_uids: Dict[int, Set[str]] = defaultdict(set)
self._model_uid_to_addr: Dict[str, str] = {}
self._model_uid_to_recover_count: Dict[str, int] = {}
self._model_uid_to_launch_args: Dict[str, Dict] = {}

self._lock = asyncio.Lock()

async def recover_sub_pool(self, address):
    """React to the sub pool process at ``address`` going down.

    The dead sub pool is removed from the main pool, then the model uid
    registered at ``address`` is looked up. That model is terminated and,
    unless its recover budget is exhausted, launched again with the launch
    arguments recorded for it.

    Nothing is relaunched when no model is registered at ``address``, when
    no launch arguments are recorded for the model, or when its remaining
    recover count is 0.

    Args:
        address: Address of the sub pool whose process is down.
    """
    logger.warning("Process %s is down.", address)
    # Xoscar does not remove the address from sub_processes.
    try:
        await self._main_pool.remove_sub_pool(address)
    except Exception as e:
        # Best effort only: keep going, but leave a trace for debugging.
        logger.debug(
            "Remove sub pool failed, address: %s, error: %s", address, e
        )

    # Resolve the uid up front so the address mapping is not being iterated
    # while terminate_model / launch_builtin_model modify it.
    model_uid = next(
        (
            uid
            for uid, addr in self._model_uid_to_addr.items()
            if addr == address
        ),
        None,
    )
    if model_uid is None:
        return

    launch_args = self._model_uid_to_launch_args.get(model_uid)
    if launch_args is None:
        logger.warning(
            "Not recreate model because the it is down during launch."
        )
        return

    # Read the remaining budget first: terminate_model drops the entry.
    recover_count = self._model_uid_to_recover_count.get(model_uid)
    try:
        await self.terminate_model(model_uid)
    except Exception as e:
        # The process is already gone, so a failed cleanup is expected.
        logger.debug(
            "Terminate model failed, model uid: %s, error: %s", model_uid, e
        )

    if recover_count is None:
        # No limit configured: always recreate.
        logger.warning("Recreating model actor %s ...", model_uid)
    elif recover_count > 0:
        logger.warning(
            "Recreating model actor %s, remain %s times ...",
            model_uid,
            recover_count - 1,
        )
        # Stored before the relaunch so that the setdefault() in
        # launch_builtin_model keeps the decremented value.
        self._model_uid_to_recover_count[model_uid] = recover_count - 1
    else:
        logger.warning("Stop recreating model actor.")
        return

    await self.launch_builtin_model(**launch_args)

@classmethod
Expand Down Expand Up @@ -414,6 +447,9 @@ async def launch_builtin_model(
self._model_uid_to_model[model_uid] = model_ref
self._model_uid_to_model_spec[model_uid] = model_description
self._model_uid_to_addr[model_uid] = subpool_address
self._model_uid_to_recover_count.setdefault(
model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
)
self._model_uid_to_launch_args[model_uid] = launch_args

# update status to READY
Expand All @@ -431,7 +467,7 @@ async def terminate_model(self, model_uid: str):
)
model_ref = self._model_uid_to_model.get(model_uid, None)
if model_ref is None:
raise ValueError(f"Model not found in the model list, uid: {model_uid}")
logger.debug("Model not found, uid: %s", model_uid)

try:
await xo.destroy_actor(model_ref)
Expand All @@ -442,12 +478,17 @@ async def terminate_model(self, model_uid: str):
try:
subpool_address = self._model_uid_to_addr[model_uid]
await self._main_pool.remove_sub_pool(subpool_address)
except Exception as e:
logger.debug(
"Remove sub pool failed, model uid: %s, error: %s", model_uid, e
)
finally:
del self._model_uid_to_model[model_uid]
del self._model_uid_to_model_spec[model_uid]
self._model_uid_to_model.pop(model_uid, None)
self._model_uid_to_model_spec.pop(model_uid, None)
self.release_devices(model_uid)
del self._model_uid_to_addr[model_uid]
del self._model_uid_to_launch_args[model_uid]
self._model_uid_to_addr.pop(model_uid, None)
self._model_uid_to_recover_count.pop(model_uid, None)
self._model_uid_to_launch_args.pop(model_uid, None)
await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATED.name}
)
Expand All @@ -465,7 +506,7 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
    """Return the actor reference registered for ``model_uid``.

    Args:
        model_uid: Uid used as the key in ``self._model_uid_to_model``.

    Returns:
        The model reference stored for ``model_uid``.

    Raises:
        ValueError: If no reference is stored for ``model_uid``.
    """
    model_ref = self._model_uid_to_model.get(model_uid)
    if model_ref is None:
        raise ValueError(f"Model not found, uid: {model_uid}")
    return model_ref

@log_sync(logger=logger)
Expand Down

0 comments on commit 45f1d49

Please sign in to comment.