
Support for max_window_layers #157


Merged (4 commits) on Feb 21, 2025
19 changes: 18 additions & 1 deletion fast_llm/layers/transformer/attention.py
@@ -274,6 +274,20 @@ def _query_key_value_backward(
input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
return input_grad


def _decide_window_size(self) -> int | None:
# NOTE: This is a temporary solution for Qwen 2.X
# https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
# TODO: make universal per layer config
window_size = self._config.window_size
if (
self._config.max_window_layers is not None
and self._layer_index < self._config.max_window_layers
Collaborator comment: I think this is incorrect because layer index starts at 1 for some reason.

):
window_size = None

return window_size

def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
sequence_first = kwargs[TransformerKwargs.sequence_first]
query, key_value = self._query_key_value(input_, sequence_first)
@@ -323,13 +337,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])


window_size = self._decide_window_size()

if self._use_flash_attention:
input_ = flash_attn(
query,
key,
value,
dropout_p=self._config.attention_dropout if self.training else 0.0,
window_size=self._config.window_size,
window_size=window_size,
causal=True,
generator=self._tensor_space.distributed.tp_generator,
softmax_scale=self._softmax_scale,
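
For illustration, a minimal standalone sketch of the selection logic in _decide_window_size, assuming window_size=512, max_window_layers=2, and the 1-based _layer_index noted in the review comment above; the decide_window_size helper below is hypothetical and not part of the diff:

def decide_window_size(layer_index: int, window_size: int | None, max_window_layers: int | None) -> int | None:
    # Layers below max_window_layers fall back to full attention (no window).
    if max_window_layers is not None and layer_index < max_window_layers:
        return None
    # All remaining layers keep the configured sliding window.
    return window_size

for layer_index in range(1, 5):
    print(layer_index, decide_window_size(layer_index, window_size=512, max_window_layers=2))
# 1 None  (full attention)
# 2 512   (sliding window)
# 3 512
# 4 512
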
14 changes: 13 additions & 1 deletion fast_llm/layers/transformer/config.py
@@ -443,6 +443,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
max_window_layers: int | None = Field(
default=None,
desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.",
hint=FieldHint.optional,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
# normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto
mlp_recompute_level: MLPRecomputeLevel = Field(
default=MLPRecomputeLevel.none,
@@ -571,4 +577,10 @@ def _validate(self) -> None:
Assert.geq(scale, 0)

def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

# Config parameter `window_size` can only be used with flash attention.
if not use_flash_attention:
Assert.is_(self.window_size, None)

return use_flash_attention
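
As a quick reference, a sketch of how the new check behaves, mirroring the construction pattern used in tests/test_config.py below (the keyword-argument instantiation of TransformerConfig is an assumption taken from those tests):

import unittest.mock

from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.transformer.config import TransformerConfig

distributed_config = unittest.mock.Mock(spec=DistributedConfig)

# bf16 training: flash attention is usable, so a sliding window is allowed.
distributed_config.training_dtype = DataType.bfloat16
assert TransformerConfig(use_flash_attention=True, window_size=512).do_use_flash_attention(distributed_config)

# fp32 training: flash attention is disabled, so window_size must stay None
# and the Assert.is_(self.window_size, None) check above raises.
distributed_config.training_dtype = DataType.float32
try:
    TransformerConfig(use_flash_attention=True, window_size=512).do_use_flash_attention(distributed_config)
except AssertionError:
    print("window_size requires flash attention (fp16/bf16 training dtype)")
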
22 changes: 22 additions & 0 deletions tests/test_attention.py
@@ -0,0 +1,22 @@
import unittest.mock
from fast_llm.layers.transformer.attention import Attention
from fast_llm.layers.transformer.config import TransformerConfig


def test_decide_window_size():
attention = unittest.mock.Mock(spec=Attention)
attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method

# Arrange - Case 1: window_size is returned (layer_index >= max_window_layers)
attention._config = TransformerConfig(window_size=512, max_window_layers=2)
attention._layer_index = 2
assert attention._decide_window_size() == 512

# Arrange - Case 2: window_size is None (layer_index < max_window_layers)
attention._config = TransformerConfig(window_size=512, max_window_layers=2)
attention._layer_index = 1
assert attention._decide_window_size() is None

# Arrange - Case 3: max_window_layers is None (always return window_size)
attention._config = TransformerConfig(window_size=512, max_window_layers=None)
assert attention._decide_window_size() == 512
35 changes: 34 additions & 1 deletion tests/test_config.py
@@ -1,8 +1,15 @@
import pathlib
import pytest
import subprocess

import unittest.mock
import yaml


from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.utils import Assert
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.data_type import DataType

from fast_llm.models.auto import trainer_registry


@@ -51,3 +58,29 @@ def test_validate_example_config():
(pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text()
)
trainer_registry["gpt"].from_dict(fast_llm_config_dict)


def test_do_use_flash_attention():
# Create a mock DistributedConfig
mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig)

# Test case 1: use_flash_attention is True and training_dtype is float16
config = TransformerConfig(use_flash_attention=True, window_size=None)
mock_distributed_config.training_dtype = DataType.float16
assert config.do_use_flash_attention(mock_distributed_config) is True

# Test case 2: use_flash_attention is False
config = TransformerConfig(use_flash_attention=False, window_size=None)
mock_distributed_config.training_dtype = DataType.float16
assert config.do_use_flash_attention(mock_distributed_config) is False

# Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16
config = TransformerConfig(use_flash_attention=True, window_size=None)
mock_distributed_config.training_dtype = DataType.float32
assert config.do_use_flash_attention(mock_distributed_config) is False

# Test case 4: use_flash_attention is False and window_size is not None
config = TransformerConfig(use_flash_attention=False, window_size=512)
mock_distributed_config.training_dtype = DataType.float32
with pytest.raises(AssertionError):
config.do_use_flash_attention(mock_distributed_config)