Tgi correct clear implementation #609

Merged: 3 commits, May 27, 2024
@@ -104,6 +104,7 @@ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
     def clear(self):
         """Clear the slot and mark it as available."""
         self._state = Slot.State.EMPTY
+        self._batch_id = None
         self._request_id = None
         self._inputs = ""
         self._generation_config = None
@@ -124,6 +125,10 @@ def id(self) -> int:
     def state(self) -> "Slot.State":
         return self._state

+    @property
+    def batch_id(self) -> int:
+        return self._batch_id
+
     @property
     def request_id(self) -> int:
         return self._request_id
@@ -140,7 +145,7 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens

-    def assign(self, request: Request, generation_config: GenerationConfig):
+    def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.

         Args:
@@ -150,6 +155,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
                 The base generation config (might be modified by the request generation parameters).
         """
         self._state = Slot.State.READY
+        self._batch_id = batch_id
         self._request_id = request.id
         self._inputs = request.inputs
         self._generation_config = copy.deepcopy(generation_config)
@@ -169,6 +175,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
         self.seed = request.parameters.seed
         # TODO: watermark
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
+        self._max_new_tokens = self._generation_config.max_new_tokens
         # TODO: stop_sequences, ignore_eos_token

     def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
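Taken together, a slot now records two extra pieces of state when a request is assigned: the id of the prefill batch it belongs to, and the request's original max_new_tokens budget. A minimal sketch, mirroring the updated unit test further down (the import paths and tokenizer choice are assumptions, not part of this diff):

    from transformers import AutoTokenizer, GenerationConfig

    # Slot and Request as imported in the updated test below.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any PreTrainedTokenizer works
    slot = Slot(0, tokenizer)
    slot.assign(3, Request(id=42, inputs="Hello"), GenerationConfig())
    assert slot.batch_id == 3    # new: lets the generator later clear the whole batch
    assert slot.request_id == 42
    # The request's max_new_tokens is also snapshotted into slot._max_new_tokens,
    # which pause() uses below to recompute the remaining budget.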
@@ -197,8 +204,9 @@ def pause(self, reset_on_pause: bool):
         if reset_on_pause:
             # Drop the last token as it will be added back when resuming the slot
             self._generated_tokens -= 1
-            # Subtract the number of cached tokens from the maximum number of tokens
-            self._generation_config.max_new_tokens -= self._generated_tokens
+            # Since generated tokens are now part of the prefill, we need to reevaluate
+            # max_new_tokens for the next generation
+            self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
         self._state = Slot.State.PAUSE

     def resume(self):
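The previous in-place subtraction under-counted the remaining budget whenever a slot was paused more than once, because the cumulative _generated_tokens total was subtracted from an already reduced value. A worked example with hypothetical token counts:

    budget = 100       # max_new_tokens requested by the client

    # Old logic: subtract the cumulative count from the current value at each pause.
    old = budget
    old -= 30          # first pause after 30 generated tokens -> 70
    old -= 50          # second pause after 50 total           -> 20 (the first 30 are counted twice)

    # New logic: always recompute from the original budget kept in _max_new_tokens.
    new = budget - 50  # second pause                          -> 50, the tokens actually left

    print(old, new)    # 20 50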
@@ -308,6 +316,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
         self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.batch_id = 0

     @property
     def info(self) -> InfoResponse:
@@ -364,7 +373,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         new_slots = []
         for request in batch.requests:
             slot = empty_slots.pop()
-            slot.assign(request, self.model.generation_config)
+            slot.assign(self.batch_id, request, self.model.generation_config)
             new_slots.append(slot)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
         if self.rebuild_cache_on_prefill:
@@ -413,7 +422,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
         logits = self.model(**model_inputs)[0]
-        generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids)
+        generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids)
+        self.batch_id += 1
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
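With the counter above, the generator numbers prefill batches itself: the value stamped on each newly assigned slot is the same one handed to _generate_token for the returned CachedBatch, and it advances once per prefill. A usage sketch (the generator and Batch objects are hypothetical):

    _, cached_a = generator.prefill(batch_a)  # first prefill: new slots tagged with batch_id 0
    _, cached_b = generator.prefill(batch_b)  # second prefill: new slots tagged with batch_id 1
    # cached_a and cached_b can later be referenced individually by ClearCache,
    # which is what the clear() rework below relies on.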
@@ -522,29 +532,33 @@ def _cached_batch(self, batch_id: int, request_ids: List):
         max_tokens = size * self.model.max_length
         return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

-    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
+    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
         """Remove requests that are not listed from the specified batch

         Args:
             batch_id (`int`):
                 The id of a cached batch.
-            request_ids(`List[int]`):
+            keep_request_ids (`List[int]`):
                 The list of requests that must be kept.

         Return:
             A `CachedBatch` containing the pending requests.
         """
-        self._clear(request_ids)
-        return self._cached_batch(batch_id, request_ids)
-
-    def clear(self):
-        """Remove all requests from the generator"""
-        return self._clear([])
-
-    def _clear(self, request_ids: List):
+        keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
+        self._clear(keep_slot_ids)
+        return self._cached_batch(batch_id, keep_request_ids)
+
+    def clear(self, batch_id: Optional[int] = None):
+        """Remove a subset or all requests from the generator"""
+        keep_ids = []
+        if batch_id is not None:
+            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
+        return self._clear(keep_ids)
+
+    def _clear(self, keep_slot_ids: List):
         for slot in self.slots:
-            if slot.state != Slot.State.EMPTY and slot.request_id not in request_ids:
-                logger.debug(f"Removing request {slot.request_id}")
+            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
+                logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
                 slot.clear()

     @classmethod
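filter() now translates the request ids to keep into slot ids before delegating to _clear(), and clear() can release a single batch by keeping every slot whose batch_id differs. A toy re-implementation over plain dictionaries (not the real generator class) to show which slots survive each call:

    slots = [
        {"id": 0, "batch_id": 0, "request_id": 11, "empty": False},
        {"id": 1, "batch_id": 0, "request_id": 12, "empty": False},
        {"id": 2, "batch_id": 1, "request_id": 13, "empty": False},
    ]

    def _clear(keep_slot_ids):
        for slot in slots:
            if not slot["empty"] and slot["id"] not in keep_slot_ids:
                slot["empty"] = True  # the real code calls slot.clear()

    def clear(batch_id=None):
        keep_ids = []
        if batch_id is not None:
            keep_ids = [s["id"] for s in slots if s["batch_id"] != batch_id]
        _clear(keep_ids)

    clear(batch_id=0)                   # slots 0 and 1 are released, slot 2 keeps decoding
    print([s["empty"] for s in slots])  # [True, True, False]
    clear()                             # no batch id -> everything is released, as before
    print([s["empty"] for s in slots])  # [True, True, True]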
@@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context):

     async def ClearCache(self, request, context):
         if request.HasField("id"):
-            logger.warning(f"Clearing all batches instead of batch {request.id} only.")
-            self.generator.clear()
+            self.generator.clear(request.id)
+        else:
+            self.generator.clear()
         return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
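The servicer therefore honours a targeted ClearCacheRequest instead of logging a warning and wiping every slot; a bare request (no id field) still clears everything. Equivalent synchronous dispatch, for illustration only:

    def clear_cache(generator, request):
        if request.HasField("id"):
            generator.clear(request.id)  # release only the slots tagged with that batch id
        else:
            generator.clear()            # previous behaviour: release everything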
@@ -33,7 +33,7 @@ def tokenizer(request):
 def test_decode_streaming(tokenizer, input_text, generated_text):
     slot = Slot(0, tokenizer)
     request = Request(id=0, inputs=input_text)
-    slot.assign(request, GenerationConfig())
+    slot.assign(0, request, GenerationConfig())
     assert slot.cached_text == input_text

     inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
text-generation-inference/tgi-entrypoint.sh (1 addition, 1 deletion)
@@ -13,4 +13,4 @@ ${SCRIPT_DIR}/tgi_env.py $@

 source $ENV_FILEPATH

-text-generation-launcher $@
+exec text-generation-launcher $@
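Using exec here is the usual container-entrypoint idiom: the launcher replaces the wrapper shell rather than running as its child, so signals such as SIGTERM reach it directly and its exit status propagates.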