From cf838d04e23b0fa5aec9a86153647788715ea745 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Thu, 23 May 2024 12:42:27 +0000
Subject: [PATCH 1/3] fix(tgi): allow wrapper script to catch SIGTERM

---
 text-generation-inference/tgi-entrypoint.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text-generation-inference/tgi-entrypoint.sh b/text-generation-inference/tgi-entrypoint.sh
index b9449d74f..b959a7958 100755
--- a/text-generation-inference/tgi-entrypoint.sh
+++ b/text-generation-inference/tgi-entrypoint.sh
@@ -13,4 +13,4 @@ ${SCRIPT_DIR}/tgi_env.py $@
 
 source $ENV_FILEPATH
 
-text-generation-launcher $@
+exec text-generation-launcher $@

From 99f1937fc224f1b0ba8d598cc1ae605995069e0a Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Thu, 23 May 2024 13:26:46 +0000
Subject: [PATCH 2/3] fix(tgi): bogus max_new_tokens with static batching

---
 .../server/text_generation_server/generator.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 5b51b477c..d925488df 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -169,6 +169,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
         self.seed = request.parameters.seed
         # TODO: watermark
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
+        self._max_new_tokens = self._generation_config.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token
 
     def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
@@ -197,8 +198,9 @@ def pause(self, reset_on_pause: bool):
         if reset_on_pause:
             # Drop the last token as it will be added back when resuming the slot
             self._generated_tokens -= 1
-        # Subtract the number of cached tokens from the maximum number of tokens
-        self._generation_config.max_new_tokens -= self._generated_tokens
+            # Since generated tokens are now part of the prefill, we need to reevaluate
+            # max_new_tokens for the next generation
+            self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
         self._state = Slot.State.PAUSE
 
     def resume(self):

From 58eafcba6578b4da18c4a1d38fe11562b8f759e3 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Thu, 23 May 2024 15:54:26 +0000
Subject: [PATCH 3/3] fix(tgi): allow clearing requests from a single batch

When all requests from a prefill batch are cancelled, the router will
not send a filter request, but rather a clear cache request with the
batch_id. We previously ignored that value and cleared everything.
---
 .../text_generation_server/generator.py     | 42 ++++++++++++-------
 .../server/text_generation_server/server.py |  5 ++-
 .../tests/server/test_generator_slot.py     |  2 +-
 3 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index d925488df..c6f3bb193 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -104,6 +104,7 @@ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
     def clear(self):
         """Clear the slot and mark it as available."""
         self._state = Slot.State.EMPTY
+        self._batch_id = None
         self._request_id = None
         self._inputs = ""
         self._generation_config = None
@@ -124,6 +125,10 @@ def id(self) -> int:
     def state(self) -> "Slot.State":
         return self._state
 
+    @property
+    def batch_id(self) -> int:
+        return self._batch_id
+
     @property
     def request_id(self) -> int:
         return self._request_id
@@ -140,7 +145,7 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens
 
-    def assign(self, request: Request, generation_config: GenerationConfig):
+    def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.
 
         Args:
@@ -150,6 +155,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
                 The base generation config (might be modified by the request generation parameters).
         """
         self._state = Slot.State.READY
+        self._batch_id = batch_id
         self._request_id = request.id
         self._inputs = request.inputs
         self._generation_config = copy.deepcopy(generation_config)
@@ -310,6 +316,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
         self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.batch_id = 0
 
     @property
     def info(self) -> InfoResponse:
@@ -366,7 +373,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         new_slots = []
         for request in batch.requests:
             slot = empty_slots.pop()
-            slot.assign(request, self.model.generation_config)
+            slot.assign(self.batch_id, request, self.model.generation_config)
             new_slots.append(slot)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
         if self.rebuild_cache_on_prefill:
@@ -415,7 +422,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
         logits = self.model(**model_inputs)[0]
-        generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids)
+        generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids)
+        self.batch_id += 1
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
@@ -524,29 +532,33 @@ def _cached_batch(self, batch_id: int, request_ids: List):
         max_tokens = size * self.model.max_length
         return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
 
-    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
+    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
         """Remove requests that are not listed from the specified batch
 
         Args:
             batch_id (`int`):
                 The id of a cached batch.
-            request_ids(`List[int]`):
+            keep_request_ids(`List[int]`):
                 The list of requests that must be kept.
 
         Return:
             A `CachedBatch` containing the pending requests.
         """
-        self._clear(request_ids)
-        return self._cached_batch(batch_id, request_ids)
-
-    def clear(self):
-        """Remove all requests from the generator"""
-        return self._clear([])
-
-    def _clear(self, request_ids: List):
+        keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
+        self._clear(keep_slot_ids)
+        return self._cached_batch(batch_id, keep_request_ids)
+
+    def clear(self, batch_id: Optional[int] = None):
+        """Remove a subset or all requests from the generator"""
+        keep_ids = []
+        if batch_id is not None:
+            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
+        return self._clear(keep_ids)
+
+    def _clear(self, keep_slot_ids: List):
         for slot in self.slots:
-            if slot.state != Slot.State.EMPTY and slot.request_id not in request_ids:
-                logger.debug(f"Removing request {slot.request_id}")
+            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
+                logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
                 slot.clear()
 
     @classmethod
diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py
index f8bc8ff7b..8eb2592d6 100644
--- a/text-generation-inference/server/text_generation_server/server.py
+++ b/text-generation-inference/server/text_generation_server/server.py
@@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context):
 
     async def ClearCache(self, request, context):
         if request.HasField("id"):
-            logger.warning(f"Clearing all batches instead of batch {request.id} only.")
-            self.generator.clear()
+            self.generator.clear(request.id)
+        else:
+            self.generator.clear()
         return generate_pb2.ClearCacheResponse()
 
     async def FilterBatch(self, request, context):
diff --git a/text-generation-inference/tests/server/test_generator_slot.py b/text-generation-inference/tests/server/test_generator_slot.py
index 2f243b5d4..459ee3e5b 100644
--- a/text-generation-inference/tests/server/test_generator_slot.py
+++ b/text-generation-inference/tests/server/test_generator_slot.py
@@ -33,7 +33,7 @@ def tokenizer(request):
 def test_decode_streaming(tokenizer, input_text, generated_text):
     slot = Slot(0, tokenizer)
     request = Request(id=0, inputs=input_text)
-    slot.assign(request, GenerationConfig())
+    slot.assign(0, request, GenerationConfig())
     assert slot.cached_text == input_text
 
     inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
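
Note on the third patch: the sketch below is a minimal, self-contained illustration of the batch-scoped clear behaviour, not part of the patches and not the repository's code. The _Slot and _Generator names are hypothetical stand-ins for the classes touched above; the point is to show why remembering the prefill batch id per slot lets a ClearCache request for one batch leave the slots of other batches untouched.

# Illustrative sketch only: simplified stand-ins, not the optimum-neuron classes.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _Slot:
    id: int
    batch_id: Optional[int] = None
    request_id: Optional[int] = None

    @property
    def empty(self) -> bool:
        return self.request_id is None

    def clear(self) -> None:
        self.batch_id = None
        self.request_id = None


class _Generator:
    def __init__(self, num_slots: int):
        self.slots = [_Slot(i) for i in range(num_slots)]

    def clear(self, batch_id: Optional[int] = None) -> None:
        # With a batch_id, keep the slots that belong to other batches;
        # without one, keep nothing (the previous "clear everything" behaviour).
        keep_slot_ids: List[int] = []
        if batch_id is not None:
            keep_slot_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        for slot in self.slots:
            if not slot.empty and slot.id not in keep_slot_ids:
                slot.clear()


# Two prefill batches are in flight; clearing batch 0 leaves batch 1 untouched.
generator = _Generator(num_slots=2)
generator.slots[0].batch_id, generator.slots[0].request_id = 0, 10
generator.slots[1].batch_id, generator.slots[1].request_id = 1, 11
generator.clear(batch_id=0)
assert generator.slots[0].empty and not generator.slots[1].empty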