fix(tgi): allow clearing requests from a single batch
When all requests from a prefill batch are cancelled, the router will not send
a filter request, but rather a clear cache request with the batch_id.
We previously ignored that value and cleared everything.
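
As a rough sketch of the intended semantics (simplified stand-ins for the Slot and Generator classes touched below, not the actual implementation), clearing with a batch id should only release the slots that were assigned to that batch:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _Slot:  # minimal stand-in for the server's Slot class
    id: int
    batch_id: Optional[int] = None
    request_id: Optional[int] = None

    def clear(self):
        self.batch_id = None
        self.request_id = None

def clear(slots: List[_Slot], batch_id: Optional[int] = None):
    """Clear every slot, or only the slots assigned to `batch_id` when one is given."""
    keep_ids = [] if batch_id is None else [s.id for s in slots if s.batch_id != batch_id]
    for slot in slots:
        if slot.request_id is not None and slot.id not in keep_ids:
            slot.clear()

# Only the slot that belongs to batch 1 is released; batch 0 keeps its request.
slots = [_Slot(0, batch_id=0, request_id=10), _Slot(1, batch_id=1, request_id=11)]
clear(slots, batch_id=1)
assert slots[0].request_id == 10 and slots[1].request_id is None
```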
dacorvo committed May 24, 2024
1 parent 98422bc commit fcce2f8
Showing 2 changed files with 30 additions and 17 deletions.
@@ -104,6 +104,7 @@ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
def clear(self):
"""Clear the slot and mark it as available."""
self._state = Slot.State.EMPTY
self._batch_id = None
self._request_id = None
self._inputs = ""
self._generation_config = None
@@ -124,6 +125,10 @@ def id(self) -> int:
def state(self) -> "Slot.State":
return self._state

@property
def batch_id(self) -> int:
return self._batch_id

@property
def request_id(self) -> int:
return self._request_id
@@ -140,7 +145,7 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> int:
return self._generated_tokens

def assign(self, request: Request, generation_config: GenerationConfig):
def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.
Args:
@@ -150,6 +155,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
The base generation config (might be modified by the request generation parameters).
"""
self._state = Slot.State.READY
self._batch_id = batch_id
self._request_id = request.id
self._inputs = request.inputs
self._generation_config = copy.deepcopy(generation_config)
@@ -310,6 +316,7 @@ def __init__(
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
self.batch_id = 0

@property
def info(self) -> InfoResponse:
@@ -366,7 +373,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
new_slots = []
for request in batch.requests:
slot = empty_slots.pop()
slot.assign(request, self.model.generation_config)
slot.assign(self.batch_id, request, self.model.generation_config)
new_slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
if self.rebuild_cache_on_prefill:
@@ -415,7 +422,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids)
generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids)
self.batch_id += 1
# Reactivate previously active slots for the next decode
for i, slot in enumerate(active_slots):
slot.resume()
@@ -524,29 +532,33 @@ def _cached_batch(self, batch_id: int, request_ids: List):
max_tokens = size * self.model.max_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
"""Remove requests that are not listed from the specified batch
Args:
batch_id (`int`):
The id of a cached batch.
request_ids(`List[int]`):
keep_request_ids(`List[int]`):
The list of requests that must be kept.
Return:
A `CachedBatch` containing the pending requests.
"""
self._clear(request_ids)
return self._cached_batch(batch_id, request_ids)

def clear(self):
"""Remove all requests from the generator"""
return self._clear([])

def _clear(self, request_ids: List):
keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
self._clear(keep_slot_ids)
return self._cached_batch(batch_id, keep_request_ids)

def clear(self, batch_id: Optional[int] = None):
"""Remove a subset or all requests from the generator"""
keep_ids = []
if batch_id is not None:
keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
return self._clear(keep_ids)

def _clear(self, keep_slot_ids: List):
for slot in self.slots:
if slot.state != Slot.State.EMPTY and slot.request_id not in request_ids:
logger.debug(f"Removing request {slot.request_id}")
if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
slot.clear()

@classmethod
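Since `_clear` now works on slot ids rather than request ids, `filter` first has to map the request ids the router wants to keep onto the slots that hold them. A minimal illustration of that mapping, using plain tuples instead of the real Slot objects:

```python
# Hypothetical (slot_id, request_id) pairs standing in for Slot objects; None means empty.
slots = [(0, 100), (1, 101), (2, None)]

def filter_slots(slots, keep_request_ids):
    # Translate the request ids to keep into slot ids, then drop every other request.
    keep_slot_ids = [sid for sid, rid in slots if rid in keep_request_ids]
    return [(sid, rid if sid in keep_slot_ids else None) for sid, rid in slots]

print(filter_slots(slots, keep_request_ids=[101]))
# [(0, None), (1, 101), (2, None)]
```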
@@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context):

async def ClearCache(self, request, context):
if request.HasField("id"):
logger.warning(f"Clearing all batches instead of batch {request.id} only.")
self.generator.clear()
self.generator.clear(request.id)
else:
self.generator.clear()
return generate_pb2.ClearCacheResponse()

async def FilterBatch(self, request, context):
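For reference, the router reaches this handler through a `ClearCacheRequest` whose `id` field is optional, which is how the handler can tell "clear this batch" from "clear everything". Assuming the usual generated `generate_pb2` stubs (the exact import path may differ per server layout), the two cases look roughly like:

```python
from text_generation_server.pb import generate_pb2  # generated gRPC stubs (path may vary)

# Clear only the slots assigned to batch 3: HasField("id") is True,
# so the handler above calls generator.clear(3).
req_one_batch = generate_pb2.ClearCacheRequest(id=3)

# Clear everything: no id is set, HasField("id") is False,
# so the handler falls back to generator.clear().
req_all = generate_pb2.ClearCacheRequest()
```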
