@@ -173,14 +173,20 @@ def __init__(self,
173
173
if self .enforce_eager is None :
174
174
self .enforce_eager = False
175
175
176
- if (not self .disable_sliding_window
177
- and self .hf_text_config .model_type == "gemma2"
178
- and self .hf_text_config .sliding_window is not None ):
176
+ sliding_window = getattr (self .hf_text_config , "sliding_window" , None )
177
+ has_interleaved_attention = (sliding_window is not None ) and (
178
+ isinstance (sliding_window , list ) or
179
+ (self .hf_text_config .model_type in ["gemma2" ]))
180
+
181
+ if (not self .disable_sliding_window and has_interleaved_attention ):
182
+ sliding_window_len_min = get_min_sliding_window (
183
+ self .hf_text_config .sliding_window )
184
+
179
185
print_warning_once (
180
- "Gemma 2 uses sliding window attention for every odd layer , "
186
+ f" { self . hf_text_config . model_type } has interleaved attention, "
181
187
"which is currently not supported by vLLM. Disabling sliding "
182
188
"window and capping the max length to the sliding window size "
183
- f"({ self . hf_text_config . sliding_window } )." )
189
+ f"({ sliding_window_len_min } )." )
184
190
self .disable_sliding_window = True
185
191
186
192
self .max_model_len = _get_and_verify_max_len (
@@ -431,7 +437,8 @@ def verify_with_parallel_config(
431
437
"pipeline parallelism currently. Disabling it." )
432
438
self .use_async_output_proc = False
433
439
434
- def get_hf_config_sliding_window (self ) -> Optional [int ]:
440
+ def get_hf_config_sliding_window (
441
+ self ) -> Union [Optional [int ], List [Optional [int ]]]:
435
442
"""Get the sliding window size, or None if disabled."""
436
443
437
444
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -442,8 +449,9 @@ def get_hf_config_sliding_window(self) -> Optional[int]:
442
449
return None
443
450
return getattr (self .hf_text_config , "sliding_window" , None )
444
451
445
- def get_sliding_window (self ) -> Optional [int ]:
446
- """Get the sliding window size, or None if disabled."""
452
+ def get_sliding_window (self ) -> Optional [Union [int , List [Optional [int ]]]]:
453
+ """Get the sliding window size, or None if disabled.
454
+ """
447
455
# If user disables sliding window, return None.
448
456
if self .disable_sliding_window :
449
457
return None
@@ -1717,7 +1725,7 @@ def _get_and_verify_max_len(
1717
1725
hf_config : PretrainedConfig ,
1718
1726
max_model_len : Optional [int ],
1719
1727
disable_sliding_window : bool ,
1720
- sliding_window_len : Optional [int ],
1728
+ sliding_window_len : Optional [Union [ int , List [ Optional [ int ]]] ],
1721
1729
spec_target_max_model_len : Optional [int ] = None ,
1722
1730
) -> int :
1723
1731
"""Get and verify the model's maximum length."""
@@ -1750,10 +1758,12 @@ def _get_and_verify_max_len(
1750
1758
# If sliding window is manually disabled, max_length should be less
1751
1759
# than the sliding window length in the model config.
1752
1760
if disable_sliding_window and sliding_window_len is not None :
1753
- max_len_key = ("sliding_window"
1754
- if sliding_window_len < derived_max_model_len else
1755
- max_len_key )
1756
- derived_max_model_len = min (derived_max_model_len , sliding_window_len )
1761
+
1762
+ sliding_window_len_min = get_min_sliding_window (sliding_window_len )
1763
+ max_len_key = "sliding_window" \
1764
+ if sliding_window_len_min < derived_max_model_len else max_len_key
1765
+ derived_max_model_len = min (derived_max_model_len ,
1766
+ sliding_window_len_min )
1757
1767
1758
1768
# If none of the keys were found in the config, use a default and
1759
1769
# log a warning.
@@ -1836,6 +1846,14 @@ def _get_and_verify_max_len(
1836
1846
return int (max_model_len )
1837
1847
1838
1848
1849
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by a config value.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which ``None`` entries are skipped.

    Returns:
        The smallest non-``None`` entry when given a list; otherwise
        ``sliding_window`` itself, unchanged.

    Raises:
        ValueError: If a list is given that has no non-``None`` entry
            (including an empty list), since no minimum exists.
    """
    # A non-list value is already a single window size; pass it through.
    if not isinstance(sliding_window, list):
        return sliding_window

    window_sizes = [size for size in sliding_window if size is not None]
    if not window_sizes:
        # Calling min() on an empty sequence would also raise ValueError,
        # but with a generic message that does not identify the bad config.
        raise ValueError(
            "Cannot determine the minimum sliding window: `sliding_window` "
            f"has no non-None entries (got {sliding_window!r}).")
    return min(window_sizes)
1855
+
1856
+
1839
1857
def get_served_model_name (model : str ,
1840
1858
served_model_name : Optional [Union [str , List [str ]]]):
1841
1859
"""
0 commit comments