Skip to content

Commit

Permalink
format: make mypy happy (opendatahub-io#24)
Browse files Browse the repository at this point in the history
`format.sh` now runs mypy checks after pulling in upstream changes. This
PR applies the mypy-suggested modifications to our code.

---------

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
  • Loading branch information
tjohnson31415 authored May 8, 2024
1 parent 4c758aa commit 2caabff
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 21 deletions.
23 changes: 12 additions & 11 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.pb import generation_pb2_grpc
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
BatchedGenerationResponse,
Expand Down Expand Up @@ -54,15 +54,15 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
if not isinstance(e, AbortError):
if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
context = kwargs.get("context", None) or args[-1]
logger.exception(f"{func.__name__} caused GPU OOM error")
logger.exception("%s caused GPU OOM error", func.__name__)
service_metrics.count_request_failure(FailureReasonLabel.OOM)
await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
else:
if "generate" in func.__name__.lower():
service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
else:
service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
logger.exception(f"{func.__name__} failed")
logger.exception("%s failed", func.__name__)
raise e


Expand Down Expand Up @@ -295,7 +295,7 @@ def _convert_output(self,
text=output.text[text_start_offset:],
generated_token_count=len(output.token_ids),
stop_reason=stop_reason,
stop_sequence=stop_sequence,
stop_sequence=stop_sequence if stop_sequence else '',
)

if resp_options.generated_tokens:
Expand Down Expand Up @@ -413,7 +413,8 @@ async def _validate_and_convert_params(

@staticmethod
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
time_limit_reached: bool) -> Tuple['StopReason', str]:
time_limit_reached: bool
) -> Tuple[StopReason.ValueType, Optional[str]]:
finish_reason = output.finish_reason
stop_sequence = None
if finish_reason is None:
Expand All @@ -433,20 +434,20 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
stop_sequence = stop_str_or_tok
else:
logger.warning(
f"Unexpected stop_reason type: {type(stop_str_or_tok)}"
"Unexpected stop_reason type: %s", type(stop_str_or_tok)
)
elif finish_reason == "abort":
stop_reason = StopReason.CANCELLED
else:
logger.warning(f"Unrecognized finish_reason: {finish_reason}")
logger.warning("Unrecognized finish_reason: %s", finish_reason)
stop_reason = StopReason.CANCELLED

return stop_reason, stop_sequence

def _convert_tokens(
self,
token_ids: list[int],
logprobs_list: Optional[list[Dict[int, Logprob]]],
token_ids: List[int],
logprobs_list: Optional[List[Dict[int, Logprob]]],
include_logprobs: bool,
include_ranks: bool,
top_n_tokens: int,
Expand Down Expand Up @@ -499,7 +500,7 @@ async def _validate_prompt_and_tokenize(
# "max_length": truncate_input_tokens} \
# if truncate_input_tokens is not None else {
# "truncation": True, "max_length": max_model_len + 1}
tokenize_kwargs = {}
tokenize_kwargs: Dict[str, Any] = {}

input_ids = await self.tokenizer_group.encode_async(
prompt, **tokenize_kwargs)
Expand Down Expand Up @@ -661,6 +662,6 @@ async def start_grpc_server(engine: AsyncLLMEngine,
server.add_insecure_port(listen_on)

await server.start()
logger.info(f"gRPC Server started at {listen_on}")
logger.info("gRPC Server started at %s", listen_on)

return server
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import asyncio
import importlib
import inspect
Expand Down
8 changes: 4 additions & 4 deletions vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
if args.max_batch_size is not None:
# Existing MAX_BATCH_SIZE settings in TGIS configs may not necessarily
# be best for vLLM so we'll just log a warning for now
logger.warn(
f"max_batch_size is set to {args.max_batch_size} but will be "
f"ignored for now. max_num_seqs can be used if this is still "
f"needed.")
logger.warning(
"max_batch_size is set to %d but will be ignored for now."
"max_num_seqs can be used if this is still needed.",
args.max_batch_size)

if args.tls_cert_path:
args.ssl_certfile = args.tls_cert_path
Expand Down
9 changes: 4 additions & 5 deletions vllm/tgis_utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str,
level = logging.WARN
else:
level = logging.INFO
logger.log(
level, f"{span_str}: {kind_log} generated "
f"{response.generated_token_count} tokens before "
f"{stop_reason_str}, output {output_len} chars: "
f"{short_output}")
logger.log(level,
"%s: %s generated %d tokens before %s, output %d chars: %s",
span_str, kind_log, response.generated_token_count,
stop_reason_str, output_len, short_output)


def _truncate(text: str, len_: int) -> bytes:
Expand Down

0 comments on commit 2caabff

Please sign in to comment.