Commit 51f86bf

[mypy][CI/Build] Fix mypy errors (#7929)

Parent: c166e7e

File tree

5 files changed: +24 −9 lines

tests/samplers/test_sampler.py
vllm/assets/audio.py
vllm/entrypoints/openai/rpc/client.py
vllm/multimodal/base.py
vllm/sequence.py

tests/samplers/test_sampler.py

5 additions & 0 deletions

@@ -418,6 +418,7 @@ def run_test_case(*, expected_penalization: List[bool],
         prompt_len = seq_data.get_prompt_len()
         seq_lens.append(prompt_len)

+        assert sgm.sampling_params is not None
         if sgm.sampling_params.prompt_logprobs:
             # with prompt_logprobs each token in the prompt has a row in
             # logits
@@ -533,6 +534,8 @@ def test_sampling():

     for i, (sequence_output, metadata) in enumerate(
             zip(sampler_output, seq_group_metadata_list)):
+        assert metadata.sampling_params is not None
+
         if metadata.sampling_params.use_beam_search:
             continue

@@ -550,6 +553,8 @@ def test_sampling():
         assert expected_tokens_item is not None

         for n, nth_output in enumerate(sequence_output.samples):
+            assert metadata.sampling_params is not None
+
             if (metadata.sampling_params.temperature == 0
                     or metadata.sampling_params.seed is not None):
                 # Ensure exact matches for greedy or random with seed
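
The pattern throughout this commit is assert-based narrowing: once sampling_params becomes Optional[SamplingParams] (see vllm/sequence.py below), mypy rejects attribute access on it until the None case is ruled out. A minimal, self-contained sketch of the idiom, with hypothetical Params/Metadata classes standing in for vLLM's types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:
    temperature: float = 1.0

@dataclass
class Metadata:
    # May legitimately be None, mirroring the Optional field in vllm/sequence.py
    sampling_params: Optional[Params] = None

def is_greedy(metadata: Metadata) -> bool:
    # Without this assert, mypy reports:
    #   Item "None" of "Optional[Params]" has no attribute "temperature"
    assert metadata.sampling_params is not None
    # The assert narrows Optional[Params] to Params for the rest of the scope
    return metadata.sampling_params.temperature == 0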

vllm/assets/audio.py

3 additions & 1 deletion

@@ -19,7 +19,9 @@ def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:

         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
-        return librosa.load(audio_path, sr=None)
+        y, sr = librosa.load(audio_path, sr=None)
+        assert isinstance(sr, int)
+        return y, sr

     @property
     def url(self) -> str:
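
As I understand librosa's annotations, librosa.load returns the sample rate as a float in general, so returning its result directly from a function declared -> Tuple[np.ndarray, int] fails type checking; unpacking and narrowing with isinstance satisfies mypy and adds a runtime check. A short sketch of the same idiom, assuming librosa and numpy are installed:

from typing import Tuple

import librosa
import numpy as np

def load_native(path: str) -> Tuple[np.ndarray, int]:
    y, sr = librosa.load(path, sr=None)  # sr=None keeps the file's native rate
    assert isinstance(sr, int)  # narrows the annotated float to int for mypy
    return y, sr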

vllm/entrypoints/openai/rpc/client.py

3 additions & 2 deletions

@@ -101,6 +101,7 @@ def __init__(self, rpc_path: str):
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
         socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        assert isinstance(socket_limit, int)
         if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
             raise ValueError(
                 f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
@@ -141,8 +142,8 @@ async def run_proxy(self, socket_from, socket_to):
         poller.register(socket_from, zmq.constants.POLLIN)
         poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events = await poller.poll()
-            events = dict(events)
+            events_lst = await poller.poll()
+            events = dict(events_lst)
             if socket_from in events:
                 identity, msg = await socket_from.recv_multipart()
                 await socket_to.send_multipart([identity, msg])
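
The run_proxy change fixes a re-binding: poller.poll() yields a list of (socket, event) pairs, and reassigning the same name to dict(events) gives one variable two incompatible types, which mypy flags. Renaming the intermediate keeps both types precise. A sketch using the standard zmq.asyncio API:

from typing import Dict

import zmq
import zmq.asyncio

async def poll_once(poller: zmq.asyncio.Poller) -> Dict[zmq.asyncio.Socket, int]:
    # poll() resolves to a list of (socket, event-mask) pairs; giving the
    # list and the dict separate names avoids re-binding one variable to
    # two different types, which mypy rejects
    events_lst = await poller.poll()
    events = dict(events_lst)
    return events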

vllm/multimodal/base.py

12 additions & 5 deletions

@@ -14,7 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import json_map_leaves
+from vllm.utils import JSONTree, is_list_of, json_map_leaves

 logger = init_logger(__name__)

@@ -54,13 +54,14 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
             return nested_tensors

         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
-        if any(isinstance(t, list) for t in stacked):
+        if is_list_of(stacked, list):
+            # Do not stack nested lists
             return stacked

         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
+            return stacked

         return torch.stack(tensors_)

@@ -101,8 +102,14 @@ def as_kwargs(
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
-                               batched_inputs)
+        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
+
+        json_mapped = json_map_leaves(
+            lambda x: x.to(device, non_blocking=True),
+            json_inputs,
+        )
+
+        return cast(BatchedTensorInputs, json_mapped)


 _T = TypeVar("_T")
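
Two type-level fixes here. In _try_stack, is_list_of(stacked, list) replaces the any(isinstance(...)) expression; vLLM's helper is TypeGuard-based, so it narrows the list's element type where any() does not, and the shape-mismatch branch now returns stacked (typed NestedTensors) rather than the cast tensors_ list. In as_kwargs, mypy does not see BatchedTensorInputs as a JSONTree[torch.Tensor], so the value is cast into the tree type, mapped, and cast back. A simplified sketch with stand-in aliases (the real JSONTree, is_list_of, and json_map_leaves live in vllm.utils and are more general):

from typing import Callable, Dict, List, Union, cast

import torch

# Stand-in aliases; assumptions for illustration, not the real definitions
TensorTree = Union[torch.Tensor, List["TensorTree"], Dict[str, "TensorTree"]]
BatchedTensorInputs = Dict[str, "TensorTree"]

def map_leaves(fn: Callable[[torch.Tensor], torch.Tensor],
               tree: "TensorTree") -> "TensorTree":
    # Recursively apply fn to every tensor leaf of the nested structure
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [map_leaves(fn, v) for v in tree]
    return fn(tree)

def as_kwargs(batched: "BatchedTensorInputs",
              device: torch.device) -> "BatchedTensorInputs":
    tree = cast("TensorTree", batched)  # widen for the generic traversal
    mapped = map_leaves(lambda x: x.to(device, non_blocking=True), tree)
    return cast("BatchedTensorInputs", mapped)  # restore the precise alias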

vllm/sequence.py

1 addition & 1 deletion

@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
     request_id: str
     is_prompt: bool
     seq_data: Dict[int, SequenceData]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
     block_tables: Dict[int, List[int]]
     do_sample: bool = True
     pooling_params: Optional[PoolingParams] = None
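
This one-line change is the root of the asserts above: sampling_params can be None in practice (judging by the pooling_params field alongside it, presumably for pooling/embedding requests that carry no sampling parameters), so the annotation now says so and every consumer must narrow before use. A sketch of the two request shapes, with hypothetical values and a simplified stand-in class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:
    temperature: float = 1.0

@dataclass
class PoolingParams:
    pass

@dataclass
class GroupMetadata:  # simplified stand-in for SequenceGroupMetadata
    request_id: str
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams] = None

# A generation request populates sampling_params...
gen = GroupMetadata("req-0", SamplingParams(temperature=0.0))
# ...while a pooling-style request leaves it as None, which is why mypy
# now forces the "is not None" checks at each use site
emb = GroupMetadata("req-1", None, PoolingParams())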
