[mypy][CI/Build] Fix mypy errors (vllm-project#7929)

DarkLight1337 authored and Alvant committed Oct 26, 2024
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent b859703 commit 5cf2382
Showing 5 changed files with 24 additions and 9 deletions.
tests/samplers/test_sampler.py (5 additions & 0 deletions)
@@ -418,6 +418,7 @@ def run_test_case(*, expected_penalization: List[bool],
         prompt_len = seq_data.get_prompt_len()
         seq_lens.append(prompt_len)
 
+        assert sgm.sampling_params is not None
         if sgm.sampling_params.prompt_logprobs:
             # with prompt_logprobs each token in the prompt has a row in
             # logits
@@ -533,6 +534,8 @@ def test_sampling():
 
     for i, (sequence_output, metadata) in enumerate(
             zip(sampler_output, seq_group_metadata_list)):
+        assert metadata.sampling_params is not None
+
         if metadata.sampling_params.use_beam_search:
             continue
 
@@ -550,6 +553,8 @@ def test_sampling():
         assert expected_tokens_item is not None
 
         for n, nth_output in enumerate(sequence_output.samples):
+            assert metadata.sampling_params is not None
+
             if (metadata.sampling_params.temperature == 0
                     or metadata.sampling_params.seed is not None):
                 # Ensure exact matches for greedy or random with seed
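All three asserts address the same mypy diagnostic: once sampling_params is typed as Optional[SamplingParams] (see the vllm/sequence.py change at the end of this commit), any unguarded attribute access triggers a union-attr error. An assert-is-not-None narrows the type for the rest of the scope. A minimal sketch of the pattern, with illustrative names rather than vLLM's real classes:

    from typing import Optional

    class Params:
        temperature: float = 1.0

    class Meta:
        def __init__(self, params: Optional[Params]) -> None:
            self.params = params

    def temp(meta: Meta) -> float:
        # Without the assert, mypy reports:
        #   Item "None" of "Optional[Params]" has no attribute "temperature"
        assert meta.params is not None  # narrows Optional[Params] to Params
        return meta.params.temperature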
vllm/assets/audio.py (3 additions & 1 deletion)
@@ -19,7 +19,9 @@ def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
 
         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
-        return librosa.load(audio_path, sr=None)
+        y, sr = librosa.load(audio_path, sr=None)
+        assert isinstance(sr, int)
+        return y, sr
 
     @property
     def url(self) -> str:
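The property's signature promises Tuple[np.ndarray, int], but librosa annotates the returned sample rate more loosely (as a float), so returning librosa.load's result directly fails mypy's return-value check. Unpacking and asserting isinstance(sr, int) narrows the second element. A rough sketch of the same fix, assuming a hypothetical library function annotated like librosa.load:

    from typing import List, Tuple

    def load(path: str) -> Tuple[List[float], float]:  # stand-in signature
        return [], 22050

    def load_checked(path: str) -> Tuple[List[float], int]:
        y, sr = load(path)
        assert isinstance(sr, int)  # narrows the float annotation to int
        return y, sr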
vllm/entrypoints/openai/rpc/client.py (3 additions & 2 deletions)
@@ -101,6 +101,7 @@ def __init__(self, rpc_path: str):
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
         socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        assert isinstance(socket_limit, int)
        if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
             raise ValueError(
                 f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
@@ -141,8 +142,8 @@ async def run_proxy(self, socket_from, socket_to):
         poller.register(socket_from, zmq.constants.POLLIN)
         poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events = await poller.poll()
-            events = dict(events)
+            events_lst = await poller.poll()
+            events = dict(events_lst)
             if socket_from in events:
                 identity, msg = await socket_from.recv_multipart()
                 await socket_to.send_multipart([identity, msg])
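The run_proxy fix is a different class of error: mypy infers a single type per variable, and poller.poll() returns a list of (socket, event) pairs, so rebinding events from that list to the dict built from it is an incompatible-types-in-assignment error. Giving the intermediate list its own name keeps each variable at one type. The shape of the problem, reduced to a few lines:

    pairs = [("a", 1), ("b", 2)]       # List[Tuple[str, int]]
    # events = pairs
    # events = dict(events)            # mypy: incompatible types in assignment
    events_lst = pairs
    events = dict(events_lst)          # fine: Dict[str, int]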
vllm/multimodal/base.py (12 additions & 5 deletions)
@@ -14,7 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import json_map_leaves
+from vllm.utils import JSONTree, is_list_of, json_map_leaves
 
 logger = init_logger(__name__)

@@ -54,13 +54,14 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
             return nested_tensors
 
         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
-        if any(isinstance(t, list) for t in stacked):
+        if is_list_of(stacked, list):
             # Do not stack nested lists
             return stacked
 
         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
+            return stacked
 
         return torch.stack(tensors_)
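The any(isinstance(t, list) ...) check is equivalent at runtime but tells the type checker nothing about stacked, which is why the cast was needed and why returning tensors_ then tripped over list invariance (a List[torch.Tensor] is not a List[NestedTensors]); returning stacked keeps the declared nested type. is_list_of from vllm.utils appears to be a TypeGuard-style helper, so the check itself narrows the list. A minimal sketch of how such a guard works; the helper here is hypothetical, not vLLM's exact signature:

    from typing import List, TypeGuard  # TypeGuard: typing_extensions before Python 3.10

    def is_list_of_int(val: List[object]) -> TypeGuard[List[int]]:
        # Returning True convinces the checker that val is List[int]
        return all(isinstance(x, int) for x in val)

    def total(val: List[object]) -> int:
        if is_list_of_int(val):
            return sum(val)  # val is treated as List[int] in this branch
        return 0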

@@ -101,8 +102,14 @@ def as_kwargs(
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
-                               batched_inputs)
+        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
+
+        json_mapped = json_map_leaves(
+            lambda x: x.to(device, non_blocking=True),
+            json_inputs,
+        )
+
+        return cast(BatchedTensorInputs, json_mapped)
 
 
 _T = TypeVar("_T")
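json_map_leaves is generic over the JSONTree alias, while as_kwargs traffics in BatchedTensorInputs, so the two cast calls translate between the aliases in each direction. typing.cast is a purely static annotation with no runtime check or conversion, so the restructured method behaves identically:

    from typing import cast

    x: object = 3
    n = cast(int, x)  # no runtime effect; only the type checker sees it
    assert n == 3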
vllm/sequence.py (1 addition & 1 deletion)
@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
     request_id: str
     is_prompt: bool
     seq_data: Dict[int, SequenceData]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
     block_tables: Dict[int, List[int]]
     do_sample: bool = True
     pooling_params: Optional[PoolingParams] = None
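This one-line widening is the root cause behind the asserts added in tests/samplers/test_sampler.py: sampling_params can legitimately be None at runtime (presumably for requests that carry pooling_params instead), so the annotation now matches reality and every dereference must be guarded. An explicit None check is the runtime-safe alternative to an assert, sketched here with illustrative names:

    from typing import Optional

    class Params:
        seed: Optional[int] = None

    class Meta:
        def __init__(self, params: Optional[Params]) -> None:
            self.params = params

    def seed_of(meta: Meta) -> Optional[int]:
        if meta.params is None:   # guard instead of assert: no AssertionError path
            return None
        return meta.params.seed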
