Skip to content

Commit 4e25c7c

Browse files
patrickvonplaten authored and LeiWang1999 committed
Support mistral interleaved attn (vllm-project#9414)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent f5faf5f commit 4e25c7c

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

vllm/config.py

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -173,14 +173,20 @@ def __init__(self,
173173
if self.enforce_eager is None:
174174
self.enforce_eager = False
175175

176-
if (not self.disable_sliding_window
177-
and self.hf_text_config.model_type == "gemma2"
178-
and self.hf_text_config.sliding_window is not None):
176+
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
177+
has_interleaved_attention = (sliding_window is not None) and (
178+
isinstance(sliding_window, list) or
179+
(self.hf_text_config.model_type in ["gemma2"]))
180+
181+
if (not self.disable_sliding_window and has_interleaved_attention):
182+
sliding_window_len_min = get_min_sliding_window(
183+
self.hf_text_config.sliding_window)
184+
179185
print_warning_once(
180-
"Gemma 2 uses sliding window attention for every odd layer, "
186+
f"{self.hf_text_config.model_type} has interleaved attention, "
181187
"which is currently not supported by vLLM. Disabling sliding "
182188
"window and capping the max length to the sliding window size "
183-
f"({self.hf_text_config.sliding_window}).")
189+
f"({sliding_window_len_min}).")
184190
self.disable_sliding_window = True
185191

186192
self.max_model_len = _get_and_verify_max_len(
@@ -431,7 +437,8 @@ def verify_with_parallel_config(
431437
"pipeline parallelism currently. Disabling it.")
432438
self.use_async_output_proc = False
433439

434-
def get_hf_config_sliding_window(self) -> Optional[int]:
440+
def get_hf_config_sliding_window(
441+
self) -> Union[Optional[int], List[Optional[int]]]:
435442
"""Get the sliding window size, or None if disabled."""
436443

437444
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -442,8 +449,9 @@ def get_hf_config_sliding_window(self) -> Optional[int]:
442449
return None
443450
return getattr(self.hf_text_config, "sliding_window", None)
444451

445-
def get_sliding_window(self) -> Optional[int]:
446-
"""Get the sliding window size, or None if disabled."""
452+
def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
453+
"""Get the sliding window size, or None if disabled.
454+
"""
447455
# If user disables sliding window, return None.
448456
if self.disable_sliding_window:
449457
return None
@@ -1717,7 +1725,7 @@ def _get_and_verify_max_len(
17171725
hf_config: PretrainedConfig,
17181726
max_model_len: Optional[int],
17191727
disable_sliding_window: bool,
1720-
sliding_window_len: Optional[int],
1728+
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
17211729
spec_target_max_model_len: Optional[int] = None,
17221730
) -> int:
17231731
"""Get and verify the model's maximum length."""
@@ -1750,10 +1758,12 @@ def _get_and_verify_max_len(
17501758
# If sliding window is manually disabled, max_length should be less
17511759
# than the sliding window length in the model config.
17521760
if disable_sliding_window and sliding_window_len is not None:
1753-
max_len_key = ("sliding_window"
1754-
if sliding_window_len < derived_max_model_len else
1755-
max_len_key)
1756-
derived_max_model_len = min(derived_max_model_len, sliding_window_len)
1761+
1762+
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
1763+
max_len_key = "sliding_window" \
1764+
if sliding_window_len_min < derived_max_model_len else max_len_key
1765+
derived_max_model_len = min(derived_max_model_len,
1766+
sliding_window_len_min)
17571767

17581768
# If none of the keys were found in the config, use a default and
17591769
# log a warning.
@@ -1836,6 +1846,14 @@ def _get_and_verify_max_len(
18361846
return int(max_model_len)
18371847

18381848

1849+
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which ``None`` entries are skipped.
            NOTE(review): presumably the list is per-layer and ``None``
            marks a layer without a sliding window -- confirm against the
            model config that produces it.

    Returns:
        ``sliding_window`` itself when it is not a list, otherwise the
        minimum of the non-``None`` entries of the list.

    Raises:
        ValueError: If ``sliding_window`` is a list that contains no
            non-``None`` entry, so no minimum can be computed.
    """
    if isinstance(sliding_window, list):
        # `default=None` avoids the opaque "min() arg is an empty sequence"
        # error so that a descriptive one can be raised instead.
        smallest = min((size for size in sliding_window if size is not None),
                       default=None)
        if smallest is None:
            raise ValueError(
                "Cannot determine the minimum sliding window size: "
                f"{sliding_window!r} contains no non-None entries.")
        return smallest

    return sliding_window
1855+
1856+
18391857
def get_served_model_name(model: str,
18401858
served_model_name: Optional[Union[str, List[str]]]):
18411859
"""

0 commit comments

Comments
 (0)