Tgi correct clear implementation (#609)
* fix(tgi): allow wrapper script to catch SIGTERM

* fix(tgi): bogus max_new_tokens with static batching

* fix(tgi): allow clearing requests from a single batch

When all requests from a prefill batch are cancelled, the router will not send
a filter request, but rather a clear cache request with the batch_id.
We previously ignored that value and cleared everything.
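
For context, a minimal self-contained sketch of the batch-scoped clear described above; the Slot and Generator stubs here are simplified stand-ins for the classes changed in the diff below.

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

class State(Enum):
    EMPTY = 0
    READY = 1

@dataclass
class Slot:
    id: int
    state: State = State.EMPTY
    batch_id: Optional[int] = None
    request_id: Optional[int] = None

    def clear(self):
        self.state = State.EMPTY
        self.batch_id = None
        self.request_id = None

class Generator:
    def __init__(self, num_slots: int):
        self.slots = [Slot(i) for i in range(num_slots)]

    def clear(self, batch_id: Optional[int] = None):
        # No batch_id: clear everything (the previous behaviour).
        # With a batch_id: only free the slots assigned by that prefill batch.
        keep_slot_ids = []
        if batch_id is not None:
            keep_slot_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        self._clear(keep_slot_ids)

    def _clear(self, keep_slot_ids: List[int]):
        for slot in self.slots:
            if slot.state != State.EMPTY and slot.id not in keep_slot_ids:
                slot.clear()

In the gRPC handler further below, ClearCache forwards request.id to generator.clear() when the field is set and clears everything otherwise.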
dacorvo authored May 27, 2024
1 parent 639c17a commit 4a21d96
Showing 4 changed files with 36 additions and 21 deletions.
@@ -104,6 +104,7 @@ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
def clear(self):
"""Clear the slot and mark it as available."""
self._state = Slot.State.EMPTY
self._batch_id = None
self._request_id = None
self._inputs = ""
self._generation_config = None
@@ -124,6 +125,10 @@ def id(self) -> int:
def state(self) -> "Slot.State":
return self._state

@property
def batch_id(self) -> int:
return self._batch_id

@property
def request_id(self) -> int:
return self._request_id
@@ -140,7 +145,7 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> int:
return self._generated_tokens

def assign(self, request: Request, generation_config: GenerationConfig):
def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.
Args:
@@ -150,6 +155,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
The base generation config (might be modified by the request generation parameters).
"""
self._state = Slot.State.READY
self._batch_id = batch_id
self._request_id = request.id
self._inputs = request.inputs
self._generation_config = copy.deepcopy(generation_config)
@@ -169,6 +175,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
self.seed = request.parameters.seed
# TODO: watermark
self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
self._max_new_tokens = self._generation_config.max_new_tokens
# TODO: stop_sequences, ignore_eos_token

def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
@@ -197,8 +204,9 @@ def pause(self, reset_on_pause: bool):
if reset_on_pause:
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Subtract the number of cached tokens from the maximum number of tokens
self._generation_config.max_new_tokens -= self._generated_tokens
# Since generated tokens are now part of the prefill, we need to reevaluate
# max_new_tokens for the next generation
self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
self._state = Slot.State.PAUSE

def resume(self):
@@ -308,6 +316,7 @@ def __init__(
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
self.batch_id = 0

@property
def info(self) -> InfoResponse:
@@ -364,7 +373,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
new_slots = []
for request in batch.requests:
slot = empty_slots.pop()
slot.assign(request, self.model.generation_config)
slot.assign(self.batch_id, request, self.model.generation_config)
new_slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
if self.rebuild_cache_on_prefill:
@@ -413,7 +422,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids)
generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids)
self.batch_id += 1
# Reactivate previously active slots for the next decode
for i, slot in enumerate(active_slots):
slot.resume()
@@ -522,29 +532,33 @@ def _cached_batch(self, batch_id: int, request_ids: List):
max_tokens = size * self.model.max_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
"""Remove requests that are not listed from the specified batch
Args:
batch_id (`int`):
The id of a cached batch.
request_ids(`List[int]`):
keep_request_ids(`List[int]`):
The list of requests that must be kept.
Return:
A `CachedBatch` containing the pending requests.
"""
self._clear(request_ids)
return self._cached_batch(batch_id, request_ids)

def clear(self):
"""Remove all requests from the generator"""
return self._clear([])

def _clear(self, request_ids: List):
keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
self._clear(keep_slot_ids)
return self._cached_batch(batch_id, keep_request_ids)

def clear(self, batch_id: Optional[int] = None):
"""Remove a subset or all requests from the generator"""
keep_ids = []
if batch_id is not None:
keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
return self._clear(keep_ids)

def _clear(self, keep_slot_ids: List):
for slot in self.slots:
if slot.state != Slot.State.EMPTY and slot.request_id not in request_ids:
logger.debug(f"Removing request {slot.request_id}")
if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
slot.clear()

@classmethod
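The pause() change earlier in this file is the static-batching fix from the commit message: assuming _generated_tokens counts every token produced since the request was assigned (which is what the new line max_new_tokens = self._max_new_tokens - self._generated_tokens implies), subtracting it in place from an already-reduced budget double-counts tokens that were subtracted at an earlier pause, while recomputing from the original request budget does not. A small worked example with invented numbers:

# Invented numbers: the request asked for max_new_tokens = 100 and the slot is
# paused after 10 generated tokens, then again after 30 cumulative tokens.
original_budget = 100

# Old bookkeeping: decrement the current budget in place at every pause.
budget = original_budget
budget -= 10              # first pause: 90 tokens left, still correct
budget -= 30              # second pause: 60 left, but the first 10 were subtracted twice

# New bookkeeping: always recompute the remainder from the original budget.
budget = original_budget - 30   # 70 tokens left, the actual remainder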
@@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context):

async def ClearCache(self, request, context):
if request.HasField("id"):
logger.warning(f"Clearing all batches instead of batch {request.id} only.")
self.generator.clear()
self.generator.clear(request.id)
else:
self.generator.clear()
return generate_pb2.ClearCacheResponse()

async def FilterBatch(self, request, context):
@@ -33,7 +33,7 @@ def tokenizer(request):
def test_decode_streaming(tokenizer, input_text, generated_text):
slot = Slot(0, tokenizer)
request = Request(id=0, inputs=input_text)
slot.assign(request, GenerationConfig())
slot.assign(0, request, GenerationConfig())
assert slot.cached_text == input_text

inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
2 changes: 1 addition & 1 deletion text-generation-inference/tgi-entrypoint.sh
@@ -13,4 +13,4 @@ ${SCRIPT_DIR}/tgi_env.py $@

source $ENV_FILEPATH

text-generation-launcher $@
exec text-generation-launcher $@
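
The entrypoint change swaps the plain launcher invocation for exec, so text-generation-launcher replaces the wrapper shell as the running process and receives SIGTERM directly instead of the signal stopping at the wrapper. A rough Python analogue of the two behaviours (illustrative only, not part of the commit):

import os
import subprocess

# The launcher command from the script above, with its arguments omitted here.
CMD = ["text-generation-launcher"]

def launch_as_child():
    # Without exec: the wrapper stays alive as the parent process, so a SIGTERM sent
    # to it is not automatically delivered to the launcher it spawned.
    subprocess.run(CMD)

def launch_with_exec():
    # With exec: the current process image is replaced by the launcher, so signals
    # sent to this PID reach the launcher directly.
    os.execvp(CMD[0], CMD)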
