From fe53fc8d1c61b24885cc142f51c59b2b1cd10956 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 20 Feb 2025 16:02:27 +0200
Subject: [PATCH 1/4] added max_window_layers

---
 fast_llm/layers/transformer/attention.py | 13 ++++++++++++-
 fast_llm/layers/transformer/config.py    |  6 ++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index b14e3971..a2948dea 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -323,13 +323,24 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
             key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])
+
+        # NOTE: This is a temporary solution for Qwen 2.X
+        # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
+        # TODO: make universal per-layer config
+        window_size = self._config.window_size
+        if (
+            self._config.max_window_layers is not None
+            and self._layer_index < self._config.max_window_layers
+        ):
+            window_size = None
+
         if self._use_flash_attention:
             input_ = flash_attn(
                 query,
                 key,
                 value,
                 dropout_p=self._config.attention_dropout if self.training else 0.0,
-                window_size=self._config.window_size,
+                window_size=window_size,
                 causal=True,
                 generator=self._tensor_space.distributed.tp_generator,
                 softmax_scale=self._softmax_scale,
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 81d58a5a..f9275441 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -443,6 +443,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
         hint=FieldHint.feature,
         valid=skip_valid_if_none(check_field(Assert.geq, 0)),
     )
+    max_window_layers: int | None = Field(
+        default=None,
+        desc="The number of bottom layers that use full attention; layers at or above this index use SWA (Sliding Window Attention) with the configured window size.",
+        hint=FieldHint.optional,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
     # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto
     mlp_recompute_level: MLPRecomputeLevel = Field(
         default=MLPRecomputeLevel.none,

From 1d24759c43cca40af3d6d838c7b8a69ab7f7a4f0 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 20 Feb 2025 16:16:48 +0200
Subject: [PATCH 2/4] assert on window_size without flash attention

---
 fast_llm/layers/transformer/config.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index f9275441..1b4e7749 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -577,4 +577,10 @@ def _validate(self) -> None:
             Assert.geq(scale, 0)

     def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
-        return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+        use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+
+        # The `window_size` config parameter can only be used with flash attention.
+        if not use_flash_attention:
+            Assert.is_(self.window_size, None)
+
+        return use_flash_attention

From 3e4714c0f84117a6fbd8692a23e0ed8566d4ddb2 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 20 Feb 2025 17:30:37 +0200
Subject: [PATCH 3/4] moved decision on sliding window size to a separate method for testability

---
 fast_llm/layers/transformer/attention.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index a2948dea..8071a086 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -274,6 +274,20 @@ def _query_key_value_backward(
         input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
         return input_grad

+    def _decide_window_size(self) -> int | None:
+        # NOTE: This is a temporary solution for Qwen 2.X
+        # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
+        # TODO: make universal per-layer config
+        window_size = self._config.window_size
+        if (
+            self._config.max_window_layers is not None
+            and self._layer_index < self._config.max_window_layers
+        ):
+            window_size = None
+
+        return window_size
+
     def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
         sequence_first = kwargs[TransformerKwargs.sequence_first]
         query, key_value = self._query_key_value(input_, sequence_first)
@@ -324,15 +338,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])

-        # NOTE: This is a temporary solution for Qwen 2.X
-        # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
-        # TODO: make universal per-layer config
-        window_size = self._config.window_size
-        if (
-            self._config.max_window_layers is not None
-            and self._layer_index < self._config.max_window_layers
-        ):
-            window_size = None
+        window_size = self._decide_window_size()

         if self._use_flash_attention:
             input_ = flash_attn(
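The rule implemented by _decide_window_size can be illustrated with a small standalone sketch (plain Python, independent of Fast-LLM; the values window_size=4096, max_window_layers=2 and the 4-layer loop are chosen here purely for illustration): layers whose index is below max_window_layers fall back to full attention, while all higher layers keep the configured sliding window.

    def decide_window_size(layer_index: int, window_size: int | None, max_window_layers: int | None) -> int | None:
        # Same rule as Attention._decide_window_size: bottom layers (index < max_window_layers)
        # drop the sliding window and use full attention.
        if max_window_layers is not None and layer_index < max_window_layers:
            return None
        return window_size

    for layer_index in range(4):
        print(layer_index, decide_window_size(layer_index, window_size=4096, max_window_layers=2))
    # Prints: 0 None, 1 None, 2 4096, 3 4096
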
From 341e845fe4b8dc9400681f6be6702fea9c81a7f5 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 20 Feb 2025 17:31:15 +0200
Subject: [PATCH 4/4] added test cases

---
 tests/test_attention.py | 22 ++++++++++++++++++++++
 tests/test_config.py    | 35 ++++++++++++++++++++++++++++++++++-
 2 files changed, 56 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_attention.py

diff --git a/tests/test_attention.py b/tests/test_attention.py
new file mode 100644
index 00000000..db856787
--- /dev/null
+++ b/tests/test_attention.py
@@ -0,0 +1,22 @@
+import unittest.mock
+from fast_llm.layers.transformer.attention import Attention
+from fast_llm.layers.transformer.config import TransformerConfig
+
+
+def test_decide_window_size():
+    attention = unittest.mock.Mock(spec=Attention)
+    attention._decide_window_size = Attention._decide_window_size.__get__(attention)  # Attach real method
+
+    # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 2
+    assert attention._decide_window_size() == 512
+
+    # Arrange - Case 2: window_size is None (layer_index < max_window_layers)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=2)
+    attention._layer_index = 1
+    assert attention._decide_window_size() is None
+
+    # Arrange - Case 3: max_window_layers is None (always return window_size)
+    attention._config = TransformerConfig(window_size=512, max_window_layers=None)
+    assert attention._decide_window_size() == 512
diff --git a/tests/test_config.py b/tests/test_config.py
index 9ee18549..86c99a23 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,8 +1,15 @@
 import pathlib
+import pytest
 import subprocess
-
+import unittest.mock
 import yaml
+
+from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.utils import Assert
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.config_utils.data_type import DataType
+
 from fast_llm.models.auto import trainer_registry
@@ -51,3 +58,29 @@ def test_validate_example_config():
         (pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text()
     )
     trainer_registry["gpt"].from_dict(fast_llm_config_dict)
+
+
+def test_do_use_flash_attention():
+    # Create a mock DistributedConfig
+    mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig)
+
+    # Test case 1: use_flash_attention is True and training_dtype is float16
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is True
+
+    # Test case 2: use_flash_attention is False
+    config = TransformerConfig(use_flash_attention=False, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float16
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16
+    config = TransformerConfig(use_flash_attention=True, window_size=None)
+    mock_distributed_config.training_dtype = DataType.float32
+    assert config.do_use_flash_attention(mock_distributed_config) is False
+
+    # Test case 4: use_flash_attention is False and window_size is not None
+    config = TransformerConfig(use_flash_attention=False, window_size=512)
+    mock_distributed_config.training_dtype = DataType.float32
+    with pytest.raises(AssertionError):
+        config.do_use_flash_attention(mock_distributed_config)
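
One combination the new tests do not cover is a sliding window together with usable flash attention. A possible extra case (hypothetical, not part of the patch, and assuming the same imports as tests/test_config.py above) could look like this:

    def test_do_use_flash_attention_with_window_size():
        # Hypothetical case: window_size is allowed when flash attention is actually used.
        mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig)
        mock_distributed_config.training_dtype = DataType.bfloat16
        config = TransformerConfig(use_flash_attention=True, window_size=512)
        assert config.do_use_flash_attention(mock_distributed_config) is True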