@@ -173,14 +173,20 @@ def __init__(self,
173
173
if self .enforce_eager is None :
174
174
self .enforce_eager = False
175
175
176
- if (not self .disable_sliding_window
177
- and self .hf_text_config .model_type == "gemma2"
178
- and self .hf_text_config .sliding_window is not None ):
176
+ sliding_window = getattr (self .hf_text_config , "sliding_window" , None )
177
+ has_interleaved_attention = (sliding_window is not None ) and (
178
+ isinstance (sliding_window , list ) or
179
+ (self .hf_text_config .model_type in ["gemma2" ]))
180
+
181
+ if (not self .disable_sliding_window and has_interleaved_attention ):
182
+ sliding_window_len_min = get_min_sliding_window (
183
+ self .hf_text_config .sliding_window )
184
+
179
185
print_warning_once (
180
- "Gemma 2 uses sliding window attention for every odd layer , "
186
+ f" { self . hf_text_config . model_type } has interleaved attention, "
181
187
"which is currently not supported by vLLM. Disabling sliding "
182
188
"window and capping the max length to the sliding window size "
183
- f"({ self . hf_text_config . sliding_window } )." )
189
+ f"({ sliding_window_len_min } )." )
184
190
self .disable_sliding_window = True
185
191
186
192
self .max_model_len = _get_and_verify_max_len (
@@ -431,7 +437,8 @@ def verify_with_parallel_config(
431
437
"pipeline parallelism currently. Disabling it." )
432
438
self .use_async_output_proc = False
433
439
434
- def get_hf_config_sliding_window (self ) -> Optional [int ]:
440
+ def get_hf_config_sliding_window (
441
+ self ) -> Union [Optional [int ], List [Optional [int ]]]:
435
442
"""Get the sliding window size, or None if disabled."""
436
443
437
444
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -442,7 +449,7 @@ def get_hf_config_sliding_window(self) -> Optional[int]:
442
449
return None
443
450
return getattr (self .hf_text_config , "sliding_window" , None )
444
451
445
- def get_sliding_window (self ) -> Optional [int ]:
452
+ def get_sliding_window (self ) -> Optional [Union [ int , List [ Optional [ int ]]] ]:
446
453
"""Get the sliding window size, or None if disabled.
447
454
"""
448
455
# If user disables sliding window, return None.
@@ -1689,7 +1696,7 @@ def _get_and_verify_max_len(
1689
1696
hf_config : PretrainedConfig ,
1690
1697
max_model_len : Optional [int ],
1691
1698
disable_sliding_window : bool ,
1692
- sliding_window_len : Optional [int ],
1699
+ sliding_window_len : Optional [Union [ int , List [ Optional [ int ]]] ],
1693
1700
spec_target_max_model_len : Optional [int ] = None ,
1694
1701
) -> int :
1695
1702
"""Get and verify the model's maximum length."""
@@ -1722,9 +1729,12 @@ def _get_and_verify_max_len(
1722
1729
# If sliding window is manually disabled, max_length should be less
1723
1730
# than the sliding window length in the model config.
1724
1731
if disable_sliding_window and sliding_window_len is not None :
1732
+
1733
+ sliding_window_len_min = get_min_sliding_window (sliding_window_len )
1725
1734
max_len_key = "sliding_window" \
1726
- if sliding_window_len < derived_max_model_len else max_len_key
1727
- derived_max_model_len = min (derived_max_model_len , sliding_window_len )
1735
+ if sliding_window_len_min < derived_max_model_len else max_len_key
1736
+ derived_max_model_len = min (derived_max_model_len ,
1737
+ sliding_window_len_min )
1728
1738
1729
1739
# If none of the keys were found in the config, use a default and
1730
1740
# log a warning.
@@ -1805,6 +1815,14 @@ def _get_and_verify_max_len(
1805
1815
return int (max_model_len )
1806
1816
1807
1817
1818
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by the input.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which ``None`` entries are ignored.

    Returns:
        ``sliding_window`` itself when it is not a list; otherwise the
        minimum of the list's non-None entries.

    Raises:
        ValueError: If ``sliding_window`` is a list that contains no
            non-None entry, so no minimum can be computed.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    # Use ``default=None`` so that an empty or all-None list is reported with
    # a descriptive error instead of the opaque "min() arg is an empty
    # sequence" message raised by a bare ``min()``.
    window_min = min((size for size in sliding_window if size is not None),
                     default=None)
    if window_min is None:
        raise ValueError(
            "Cannot determine the minimum sliding window: expected at least "
            f"one non-None entry, got {sliding_window!r}.")
    return window_min
1824
+
1825
+
1808
1826
def get_served_model_name (model : str ,
1809
1827
served_model_name : Optional [Union [str , List [str ]]]):
1810
1828
"""
0 commit comments