[python] parse only when new requests are received
sindhuvahinis committed Jul 10, 2024
1 parent 823563f commit b898a10
Showing 11 changed files with 49 additions and 34 deletions.
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/huggingface.py
@@ -226,7 +226,7 @@ def inference(self, inputs: Input) -> Output:
                                                   **self.input_format_args)
         requests = parsed_input.requests
         errors = parsed_input.errors
-        if len(requests) == 0:
+        if errors and len(parsed_input.batch) == len(errors):
             for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
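With the parser now skipping requests that are already active in the rolling batch, an empty parsed_input.requests no longer implies that every input failed; for rolling batch it is also the normal "no new work" case. The error path is therefore keyed off the error count instead. A minimal sketch of the old and new conditions (field names follow the diff; the helper function itself is only illustrative):

    def all_items_failed(parsed_input) -> bool:
        # Old check: no successfully parsed requests. With rolling batch this
        # also fires when there are simply no new requests, which is not an error.
        # return len(parsed_input.requests) == 0

        # New check: errors exist and every item in the incoming batch produced one.
        errors = parsed_input.errors
        return bool(errors) and len(parsed_input.batch) == len(errors)
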
21 changes: 20 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
@@ -30,6 +30,23 @@ class ParsedInput:
     batch: List = field(default_factory=lambda: [])


+def get_batch_start_id(batch, **kwargs):
+    if kwargs.get("is_rolling_batch"):
+        # for rolling batch, we only need to parse the new requests, as the active requests are kept in the cache
+        rolling_batch = kwargs.get("rolling_batch")
+        active_requests_len = len(rolling_batch.active_requests)
+        batch_size = len(batch)
+        if batch_size > active_requests_len:
+            # batch_size > active_requests_len means new requests have been received
+            return active_requests_len
+        else:
+            # no new requests were received; returning batch_size means nothing is parsed
+            return batch_size
+    else:
+        # for non-rolling batch, the python process only receives new requests
+        return 0
+
+
 def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
     """
     Preprocessing function that extracts information from Input objects.
@@ -44,7 +61,9 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
     kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
         kwargs.get("configs").rolling_batch)
     request_id_counter = get_req_id_counter(kwargs)
-    for i, input_item in enumerate(batch):
+    start_batch_id = get_batch_start_id(batch, **kwargs)
+    for i in range(start_batch_id, len(batch)):
+        input_item = batch[i]
         try:
             request_id = request_id_counter.next_id(
             ) if request_id_counter else i
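For rolling batch, get_batch_start_id makes the parse loop start at the first index not already covered by rolling_batch.active_requests. A rough usage sketch with stand-in objects (the SimpleNamespace object only mimics the real rolling batch for illustration):

    from types import SimpleNamespace

    # Engine already tracks 3 active requests; the incoming batch carries 5 items.
    rolling_batch = SimpleNamespace(active_requests=["r0", "r1", "r2"])
    batch = ["item0", "item1", "item2", "item3", "item4"]

    start = get_batch_start_id(batch, is_rolling_batch=True, rolling_batch=rolling_batch)
    print(start)  # 3 -> only item3 and item4 get parsed into new Request objects

    # Batch size equal to the active request count: start equals len(batch), nothing is parsed.
    print(get_batch_start_id(batch[:3], is_rolling_batch=True, rolling_batch=rolling_batch))  # 3

    # Non-rolling batch: the python process only ever receives new requests, so start is 0.
    print(get_batch_start_id(batch, is_rolling_batch=False))  # 0

The parsed requests therefore line up with exactly the items the rolling batch has not seen yet, which is what lets the base class drop its own de-duplication further down in this commit.
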
@@ -143,22 +143,23 @@ def translate_lmi_dist_params(self, parameters: dict):
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Adds new requests and gets output tokens from the backend.
-        :param requests: List of requests
+        :param new_requests: List of requests
         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new requests to engine
         for request in new_requests:
             request_id = str(request.id)
             params = self.translate_lmi_dist_params(request.parameters)
             request_params = RequestParams(**params)
             lora_request_params = get_lora_request_params(
                 request, self.lora_ids)
             # Constructing Request in lmi-dist library
             lmi_dist_request = Request(
                 id=request_id,
                 prompt=request.input_text,
@@ -98,17 +98,17 @@ def append_speculated_generations(self, generation, request, req_ids):
         speculated_generation = generation.speculated_generations.dequeue()

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> list:
+    def inference(self, new_requests: List[Request]) -> list:
         """
         Loads new requests and gets output tokens from all currently active requests from
         the Neuron backend.
-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :return: generated batch decoded tokens - list of dictionaries, one for
         each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         if len(new_requests) > 0:
             generations = self.scheduler.prefill(new_requests)
         else:
15 changes: 4 additions & 11 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -93,30 +93,23 @@ def get_tokenizer(self):
         raise RuntimeError("get_tokenizer function not supported")

     @abstractmethod
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Performs prefill and decode operations for the batch.
-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :return: generated batch decoded tokens
         """
         pass

-    def get_new_requests(self, requests: List[Request]) -> List[Request]:
+    def add_new_requests(self, requests: List[Request]):
         """
         Adds requests to the batch when there is availability
         :param requests: List[Request] List of requests
-        :return: list of current active requests (including those that have just been added)
         """
-        total_req_len = len(self.active_requests)
-        batch_size = len(requests)
-        if batch_size > total_req_len:
-            for i in range(total_req_len, batch_size):
-                self.active_requests.append(requests[i])
-        return self.active_requests[total_req_len:]
+        self.active_requests.extend(requests)

     @abstractmethod
     def preprocess_requests(self, requests: List[Request]):
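Because the parser already filters out indices that the rolling batch holds, add_new_requests can simply extend the active list; the old get_new_requests diffing logic would now always reduce to that. A simplified sketch of the resulting handler-side flow (error handling and output formatting omitted; names follow the diff, the wrapper function is only illustrative):

    def rolling_batch_inference(inputs, rolling_batch, input_format_args):
        # The parser starts at get_batch_start_id(...), so parsed_input.requests
        # holds only requests the rolling batch has not registered yet.
        parsed_input = parse_input_with_formatter(inputs, **input_format_args)
        new_requests = parsed_input.requests

        # Each backend's inference() first calls add_new_requests(new_requests),
        # then steps the engine over all active requests, old and new.
        return rolling_batch.inference(new_requests)
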
@@ -69,14 +69,14 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self._init_scheduler()

     @stop_on_any_exception
-    def inference(self, requests: List) -> List:
+    def inference(self, new_requests: List) -> List:
         """
         Performs prefill and decode operations for the batch.
-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :return: generated batch decoded tokens
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)

         preprocessed_new_requests = self.preprocess_requests(new_requests)
         self._prefill_and_decode(preprocessed_new_requests)
@@ -87,20 +87,20 @@ def translate_triton_params(self, parameters: dict) -> dict:
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Loads new requests into the batch when there is availability, and gets output tokens from the backend
         asynchronously.
-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :param input_data: List of input prompts.
         :param parameters: List of settings pertaining to each request.
         :param adapters: List of adapters inputs for each request in a batch
         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
         # add pending requests to active requests list
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new active requests
         for request in new_requests:
             param = self.translate_triton_params(request.parameters)
@@ -107,15 +107,15 @@ def translate_vllm_params(self, parameters: dict) -> dict:
         return parameters

     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
+    def inference(self, new_requests: List[Request]) -> List:
         """
         Adds new requests and gets output tokens from the backend.
-        :param requests: List[Request] List of requests
+        :param new_requests: List[Request] List of requests
         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
         """
-        new_requests = self.get_new_requests(requests)
+        self.add_new_requests(new_requests)
         # step 0: register new requests to engine
         for request in new_requests:
             request_id = random_uuid()
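Every concrete backend above follows the same shape after this change: register the freshly parsed requests, then run one engine step over everything that is active. A stripped-down sketch of that contract (other abstract methods omitted; _register_with_engine and _step_engine are hypothetical placeholders for the backend-specific lmi-dist, Neuron, Triton, or vLLM calls shown in the hunks):

    from typing import List

    class SketchRollingBatch(RollingBatch):

        def inference(self, new_requests: List[Request]) -> List:
            # new_requests now contains only requests this batch has not seen yet
            self.add_new_requests(new_requests)
            for request in new_requests:
                self._register_with_engine(request)  # hypothetical backend hookup
            return self._step_engine()  # hypothetical prefill/decode pass over active requests
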
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
@@ -64,7 +64,8 @@ def inference(self, inputs: Input) -> Output:

         parsed_input = parse_input_with_formatter(inputs,
                                                   **self.input_format_args)
-        if len(parsed_input.requests) == 0:
+        if parsed_input.errors and len(parsed_input.requests) == len(
+                parsed_input.errors):
             for i in range(len(parsed_input.batch)):
                 err = parsed_input.errors.get(i)
                 err = {"data": "", "last": True, "code": 424, "error": err}
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -115,7 +115,8 @@ def inference(self, inputs: Input) -> Output:

         parsed_input = parse_input_with_formatter(inputs,
                                                   **self.input_format_args)
-        if len(parsed_input.requests) == 0:
+        if parsed_input.errors and len(parsed_input.requests) == len(
+                parsed_input.errors):
             for i in range(len(parsed_input.batch)):
                 err = parsed_input.errors.get(i)
                 outputs.add(err, key="data", batch_index=i)
@@ -65,8 +65,8 @@ def reset(self):

     @profile_objects
     @stop_on_any_exception
-    def inference(self, requests: List[Request]) -> List:
-        new_requests = self.get_new_requests(requests)
+    def inference(self, new_requests: List[Request]) -> List:
+        self.add_new_requests(new_requests)

         for new_request in new_requests:
             max_len = new_request.parameters[
@@ -118,10 +118,10 @@ def reset(self):

     @profile_objects
     @stop_on_any_exception
-    def inference(self, requests: List[Request]):
+    def inference(self, new_requests: List[Request]):

         if self.dead_counter.get_id() < self.dead_trigger:
             self.dead_counter.next_id()
-            return super().inference(requests)
+            return super().inference(new_requests)
         else:
             raise RuntimeError("Death trigger triggered...")
