diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py index d925488df..c6f3bb193 100644 --- a/text-generation-inference/server/text_generation_server/generator.py +++ b/text-generation-inference/server/text_generation_server/generator.py @@ -104,6 +104,7 @@ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase): def clear(self): """Clear the slot and mark it as available.""" self._state = Slot.State.EMPTY + self._batch_id = None self._request_id = None self._inputs = "" self._generation_config = None @@ -124,6 +125,10 @@ def id(self) -> int: def state(self) -> "Slot.State": return self._state + @property + def batch_id(self) -> int: + return self._batch_id + @property def request_id(self) -> int: return self._request_id @@ -140,7 +145,7 @@ def generation_config(self) -> GenerationConfig: def generated_tokens(self) -> int: return self._generated_tokens - def assign(self, request: Request, generation_config: GenerationConfig): + def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig): """Assign a request to a slot. Args: @@ -150,6 +155,7 @@ def assign(self, request: Request, generation_config: GenerationConfig): The base generation config (might be modified by the request generation parameters). """ self._state = Slot.State.READY + self._batch_id = batch_id self._request_id = request.id self._inputs = request.inputs self._generation_config = copy.deepcopy(generation_config) @@ -310,6 +316,7 @@ def __init__( self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] + self.batch_id = 0 @property def info(self) -> InfoResponse: @@ -366,7 +373,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: new_slots = [] for request in batch.requests: slot = empty_slots.pop() - slot.assign(request, self.model.generation_config) + slot.assign(self.batch_id, request, self.model.generation_config) new_slots.append(slot) logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}") if self.rebuild_cache_on_prefill: @@ -415,7 +422,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # as they have already been generated and sent back in the last decode. model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids) logits = self.model(**model_inputs)[0] - generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids) + generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids) + self.batch_id += 1 # Reactivate previously active slots for the next decode for i, slot in enumerate(active_slots): slot.resume() @@ -524,29 +532,33 @@ def _cached_batch(self, batch_id: int, request_ids: List): max_tokens = size * self.model.max_length return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) - def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch: + def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch: """Remove requests that are not listed from the specified batch Args: batch_id (`int`): The id of a cached batch. - request_ids(`List[int]`): + keep_ids(`List[int]`): The list of requests that must be kept. Return: A `CachedBatch` containing the pending requests. """ - self._clear(request_ids) - return self._cached_batch(batch_id, request_ids) - - def clear(self): - """Remove all requests from the generator""" - return self._clear([]) - - def _clear(self, request_ids: List): + keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids] + self._clear(keep_slot_ids) + return self._cached_batch(batch_id, keep_request_ids) + + def clear(self, batch_id: Optional[int] = None): + """Remove a subset or all requests from the generator""" + keep_ids = [] + if batch_id is not None: + keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id] + return self._clear(keep_ids) + + def _clear(self, keep_slot_ids: List): for slot in self.slots: - if slot.state != Slot.State.EMPTY and slot.request_id not in request_ids: - logger.debug(f"Removing request {slot.request_id}") + if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids: + logger.info(f"Removing slot {slot.id} with request {slot.request_id}") slot.clear() @classmethod diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py index f8bc8ff7b..8eb2592d6 100644 --- a/text-generation-inference/server/text_generation_server/server.py +++ b/text-generation-inference/server/text_generation_server/server.py @@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context): async def ClearCache(self, request, context): if request.HasField("id"): - logger.warning(f"Clearing all batches instead of batch {request.id} only.") - self.generator.clear() + self.generator.clear(request.id) + else: + self.generator.clear() return generate_pb2.ClearCacheResponse() async def FilterBatch(self, request, context):