diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index b14e3971..8071a086 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -274,6 +274,20 @@ def _query_key_value_backward(
         input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
         return input_grad
+
+    def _decide_window_size(self) -> int | None:
+        # NOTE: This is a temporary solution for Qwen 2.X.
+        # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
+        # TODO: make this a universal per-layer config.
+        window_size = self._config.window_size
+        if (
+            self._config.max_window_layers is not None
+            and self._layer_index < self._config.max_window_layers
+        ):
+            window_size = None
+
+        return window_size
+
     def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
         sequence_first = kwargs[TransformerKwargs.sequence_first]
         query, key_value = self._query_key_value(input_, sequence_first)
@@ -323,13 +337,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
             key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])
+
+        window_size = self._decide_window_size()
+
         if self._use_flash_attention:
             input_ = flash_attn(
                 query,
                 key,
                 value,
                 dropout_p=self._config.attention_dropout if self.training else 0.0,
-                window_size=self._config.window_size,
+                window_size=window_size,
                 causal=True,
                 generator=self._tensor_space.distributed.tp_generator,
                 softmax_scale=self._softmax_scale,
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 81d58a5a..1b4e7749 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -443,6 +443,12 @@
         hint=FieldHint.feature,
         valid=skip_valid_if_none(check_field(Assert.geq, 0)),
     )
+    max_window_layers: int | None = Field(
+        default=None,
+        desc="Layers with index below this value use full attention; layers at or above it use sliding window attention (window_size).",
+        hint=FieldHint.optional,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
     # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto
     mlp_recompute_level: MLPRecomputeLevel = Field(
         default=MLPRecomputeLevel.none,
@@ -571,4 +577,10 @@ def _validate(self) -> None:
             Assert.geq(scale, 0)
 
     def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
-        return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+        use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+
+        # The `window_size` parameter can only be used with flash attention.
+        if not use_flash_attention:
+            Assert.is_(self.window_size, None)
+
+        return use_flash_attention
diff --git a/tests/test_attention.py b/tests/test_attention.py
new file mode 100644
index 00000000..db856787
--- /dev/null
+++ b/tests/test_attention.py
@@ -0,0 +1,22 @@
+import unittest.mock
+from fast_llm.layers.transformer.attention import Attention
+from fast_llm.layers.transformer.config import TransformerConfig
+
+
+def test_decide_window_size():
+    attention = unittest.mock.Mock(spec=Attention)
+    attention._decide_window_size = Attention._decide_window_size.__get__(attention)  # Attach the real method.
+
+    # Case 1: window_size is returned (layer_index >= max_window_layers).
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 2
+    assert attention._decide_window_size() == 512
+
+    # Case 2: window_size is None (layer_index < max_window_layers).
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 1
+    assert attention._decide_window_size() is None
+
+    # Case 3: max_window_layers is None, so window_size is always returned.
+    attention._config = TransformerConfig(window_size=512, max_window_layers=None)
+    assert attention._decide_window_size() == 512
diff --git a/tests/test_config.py b/tests/test_config.py
index 9ee18549..86c99a23 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,8 +1,15 @@
 import pathlib
+import pytest
 import subprocess
-
+import unittest.mock
 import yaml
 
+from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.utils import Assert
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.config_utils.data_type import DataType
+
 from fast_llm.models.auto import trainer_registry
@@ -51,3 +58,29 @@ def test_validate_example_config():
         (pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text()
     )
     trainer_registry["gpt"].from_dict(fast_llm_config_dict)
+
+
+def test_do_use_flash_attention():
+    # Create a mock DistributedConfig.
+    mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig)
+
+    # Test case 1: use_flash_attention is True and training_dtype is float16.
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is True
+
+    # Test case 2: use_flash_attention is False.
+    config = TransformerConfig(use_flash_attention=False, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16.
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float32
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 4: use_flash_attention is False and window_size is not None.
+    config = TransformerConfig(use_flash_attention=False, window_size=512)
+    mock_distributed_config.training_dtype = DataType.float32
+    with pytest.raises(AssertionError):
+        config.do_use_flash_attention(mock_distributed_config)
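
For reviewers, here is a minimal standalone sketch of the per-layer decision that `_decide_window_size` implements. The free function and the 4-layer loop below are illustrative only and are not part of Fast-LLM; the `window_size` / `max_window_layers` semantics mirror the patch.

```python
# Illustrative sketch only: a free-function version of the logic added in
# Attention._decide_window_size, assuming the same config semantics as the patch.


def decide_window_size(layer_index: int, window_size: int | None, max_window_layers: int | None) -> int | None:
    # Layers below `max_window_layers` fall back to full attention (no window);
    # all remaining layers keep the configured sliding window.
    if max_window_layers is not None and layer_index < max_window_layers:
        return None
    return window_size


if __name__ == "__main__":
    # Hypothetical 4-layer model with window_size=512 and max_window_layers=2:
    # the bottom layers run full attention, the top layers run a 512-token window.
    for layer_index in range(4):
        print(layer_index, decide_window_size(layer_index, window_size=512, max_window_layers=2))
    # 0 None
    # 1 None
    # 2 512
    # 3 512
```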
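
Similarly, a small sketch of the new guard in `do_use_flash_attention`: when flash attention is unavailable (disabled, or the training dtype is not fp16/bf16), a configured `window_size` is rejected. The helper name and the plain `assert` below are placeholders for the `Assert.is_` call used in the patch, not Fast-LLM APIs.

```python
# Illustrative sketch only: the window_size / flash-attention coupling enforced
# by the patch, written as a standalone helper with a plain assert standing in
# for fast_llm.utils.Assert.is_.


def resolve_use_flash_attention(use_flash_attention: bool, dtype_is_half_precision: bool, window_size: int | None) -> bool:
    use_flash = use_flash_attention and dtype_is_half_precision
    if not use_flash:
        # A sliding window requires the flash attention backend.
        assert window_size is None, "window_size can only be used with flash attention"
    return use_flash


if __name__ == "__main__":
    assert resolve_use_flash_attention(True, True, 512) is True      # flash + window: allowed
    assert resolve_use_flash_attention(False, True, None) is False   # no flash, no window: allowed
    try:
        resolve_use_flash_attention(True, False, 512)                # fp32 training + window: rejected
    except AssertionError as error:
        print(f"Rejected as expected: {error}")
```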