-
Notifications
You must be signed in to change notification settings - Fork 4.8k
[SP] add SP deny list instead of allow #7887
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
aec2c90
49e0310
2f8e77c
ce69dc0
952c3ae
6e3f2cb
874ec62
7d0a136
5868135
b0e05f0
89058fc
463cb30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,7 @@ | |
| import deepspeed.comm as dist | ||
| import importlib.metadata | ||
| import math | ||
| import re | ||
| import torch | ||
| import torch.distributed.nn | ||
|
|
||
|
|
@@ -121,6 +122,12 @@ def __init__( | |
| self.skip_all_but_last_attention_debug_mode = False | ||
| self.rotating_layer_counter = 0 # used for dev work | ||
|
|
||
| self.core_attn_implementation = None # set by register_with_transformers | ||
| self._flex_block_mask_cls = None # set by register_with_transformers | ||
| self._flex_create_block_mask = None # set by register_with_transformers | ||
| self._flex_block_mask_cached = None # cached BlockMask for flex_attention | ||
| self._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation | ||
|
|
||
| self.local_q_head_count = attn_head_count // self.world_size | ||
|
|
||
| # if we have 4 kv heads and sp 8, we need to replicate kv heads 2x | ||
|
|
@@ -272,19 +279,11 @@ def forward( | |
| key = rearrange(key, "bs hc sl hs -> sl bs hc hs") # .contiguous() | ||
| value = rearrange(value, "bs hc sl hs -> sl bs hc hs") # .contiguous() | ||
|
|
||
| # core attn like FA2 expects an unsharded `position_ids` - without which packed samples | ||
| # will return loss=nan. | ||
| # | ||
| # XXX: need to figure out if we can do the same for SDPA - as it doesn't require this and | ||
| # wants an attention mask, so possibly doing this for FA2 only? | ||
| # | ||
| # Ideally we would passing the original unsharded position_ids - but we have no way to pass | ||
| # it here as HF Transformers drops unexpected keys in `batch` - so either we need to stash | ||
| # it somewhere in UlyssesSPDataLoaderAdapter and retrieve it here or we could gather it once | ||
| # per batch and stash it inside `module` arg - I already have a machinery to figure out | ||
| # which layer number is being called below in the skip_all_but_last_attention_debug_mode | ||
| # code where rotating_layer_counter is used - so we could calculate it on the first layer | ||
| # and re-use on the remaining layers | ||
| # All attention backends need unsharded position_ids after the all-to-all. | ||
| # FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention | ||
| # need them to be monotonically increasing so causal masking works correctly. | ||
| # UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding, | ||
| # so after gathering here they reconstruct to the correct global positions. | ||
| if "position_ids" in kwargs: | ||
| position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)] | ||
| dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group) | ||
|
|
@@ -311,6 +310,36 @@ def forward( | |
| if self.kv_replication_factor > 1: | ||
| module.num_key_value_groups = query_layer.size(-3) // key_layer.size(-3) | ||
|
|
||
| # For flex_attention: the wrapper preserved the BlockMask from the model, but it | ||
| # was built for the local shard's sequence length. Rebuild it for the full gathered | ||
| # sequence length after the all-to-all. | ||
| # XXX: currently hardcodes a causal mask_mod — models with sliding window or other | ||
| # non-standard patterns would need the mask_mod extracted from the original BlockMask. | ||
| if self._flex_block_mask_cls is not None and isinstance(attention_mask, self._flex_block_mask_cls): | ||
| seq_len = query_layer.shape[2] | ||
| batch_size = query_layer.shape[0] | ||
| cache_key = (batch_size, seq_len) | ||
|
|
||
| # Cache the BlockMask — create_block_mask is expensive and the mask is the | ||
| # same for all layers within a forward pass. Only rebuild when dimensions change. | ||
| if self._flex_block_mask_cache_key != cache_key: | ||
|
|
||
| def causal_mask(batch_idx, head_idx, q_idx, kv_idx): | ||
| return q_idx >= kv_idx | ||
|
|
||
| self._flex_block_mask_cached = self._flex_create_block_mask( | ||
| mask_mod=causal_mask, | ||
| B=batch_size, | ||
| H=None, | ||
| Q_LEN=seq_len, | ||
| KV_LEN=seq_len, | ||
| device=query_layer.device, | ||
| _compile=True, | ||
| ) | ||
| self._flex_block_mask_cache_key = cache_key | ||
|
|
||
| attention_mask = self._flex_block_mask_cached | ||
|
|
||
| if not self.skip_all_but_last_attention_debug_mode: | ||
| # expects: [bs hc_l sl hs] | ||
| context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, *args, | ||
|
|
@@ -411,15 +440,34 @@ def register_with_transformers( | |
| # if we don't have the model yet at this stage | ||
| hf_model_config = AutoConfig.from_pretrained(model_name_or_path) | ||
|
|
||
| supported_attn_implementation = ["flash_attention_2", "flash_attention_3", "sdpa"] | ||
| if core_attn_implementation not in supported_attn_implementation: | ||
| # notes on the excluded ones: | ||
| # - eager: The problem is that `eager` wants an attention_mask and it creates the wrong attention mask it seems if we don't provide one - it's possible that we could somehow solve this, but it's also unlikely someone will want to use the slow eager attention with sequence parallelism | ||
| # - flex_attention: haven't tried | ||
|
|
||
| model_attn_implementation = getattr(hf_model_config, "_attn_implementation", None) | ||
| if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation: | ||
| raise ValueError( | ||
| f"core_attn_implementation='{core_attn_implementation}' does not match " | ||
| f"model config attn_implementation='{model_attn_implementation}'. " | ||
| "Set both to the same value so sequence-parallel wrapper can intercept the active attention path.") | ||
|
|
||
| # eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back | ||
| # to is_causal=True like sdpa — so it's incompatible with SP which discards masks. | ||
| unsupported_attn_implementation = ["eager", "paged|eager"] | ||
| if core_attn_implementation in unsupported_attn_implementation: | ||
| raise ValueError( | ||
| f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence" | ||
| f" parallelism. Set core_attn_implementation arg to one of {supported_attn_implementation}.") | ||
| f" parallelism because it requires a 4D attention_mask (O(n²) memory)." | ||
| f" Use 'flash_attention_2', 'flash_attention_3', 'flex_attention', 'sdpa'," | ||
| f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2').") | ||
|
|
||
| # Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers. | ||
| # Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS. | ||
| is_hub_kernel_attn = (isinstance(core_attn_implementation, str) and re.search( | ||
| r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", core_attn_implementation) is not None) | ||
| if is_hub_kernel_attn: | ||
| try: | ||
| from transformers.modeling_flash_attention_utils import lazy_import_flash_attention | ||
| except ImportError as e: | ||
| raise ImportError("Hub kernel attention requires a transformers version exposing " | ||
| "`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`.") from e | ||
| lazy_import_flash_attention(core_attn_implementation) | ||
|
|
||
| if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS: | ||
| raise ValueError( | ||
|
|
@@ -448,6 +496,16 @@ def register_with_transformers( | |
| global_seq_length=global_seq_length, | ||
| disable_in_eval=disable_in_eval, | ||
| ) | ||
| uattn.core_attn_implementation = core_attn_implementation | ||
|
|
||
| # Import flex_attention utilities once; stored on the instance for use in | ||
| # both the wrapper (to detect BlockMask) and forward() (to rebuild it). | ||
| uattn._flex_block_mask_cls = None | ||
| uattn._flex_create_block_mask = None | ||
| if core_attn_implementation == "flex_attention": | ||
| from torch.nn.attention.flex_attention import BlockMask, create_block_mask | ||
| uattn._flex_block_mask_cls = BlockMask | ||
| uattn._flex_create_block_mask = create_block_mask | ||
|
|
||
| def uattn_wrapper( | ||
| module: torch.nn.Module, | ||
|
|
@@ -459,27 +517,38 @@ def uattn_wrapper( | |
| **kwargs, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||
|
|
||
| # We are relaying on position_ids for SP to work so attention_mask has to be None | ||
| # the problem is that HF currently doesn't know anything about ALL_ATTENTION_FUNCTIONS["ulysses"] so it doesn't make a special case like for "flash_attention_2" and "sdpa" and it creates an attention mask on the fly and it breaks things. | ||
| attention_mask = None | ||
| # SP relies on position_ids (not attention_mask) for causal masking. | ||
| # HF doesn't know about the SP wrapper, so it creates an attention_mask for | ||
| # the local shard's sequence length — which is invalid after the SP all-to-all | ||
| # gathers the full sequence. A 4D mask at full sequence length would also be | ||
| # O(n²) memory. So we discard 4D tensor masks. | ||
| # | ||
| # Keep BlockMask (flex_attention) — it's a compressed sparse representation. | ||
| # It will be rebuilt for the full gathered sequence in forward(). | ||
| if uattn._flex_block_mask_cls is not None and isinstance(attention_mask, uattn._flex_block_mask_cls): | ||
| pass # keep BlockMask — will be rebuilt in forward() for gathered seq len | ||
| else: | ||
| attention_mask = None | ||
|
|
||
| attn_output, attn_weights = uattn( | ||
| module, | ||
| query, | ||
| key, | ||
| value, | ||
| attention_mask, | ||
| # XXX: fixme | ||
| *args, | ||
| **kwargs, | ||
| ) | ||
| return attn_output, attn_weights | ||
|
|
||
| # We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper) | ||
| # The problem with this approach is that we are missing on all the special use cases in HF Transformers that do things like: if self.config._attn_implementation == "flash_attention_2": ... | ||
| # So instead we hack `ALL_ATTENTION_FUNCTIONS` to override all existing keys with our implementation, since it only gets used at the point of calling the attention and that's what we want, all other code branches relying on the original core `attn_implementation` will still be executed. This is what we called "Being John Malkovich" | ||
| for key in ALL_ATTENTION_FUNCTIONS.keys(): | ||
| ALL_ATTENTION_FUNCTIONS[key] = uattn_wrapper | ||
| # The problem with that approach is that we'd miss all the special-case branches in | ||
| # HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...` | ||
| # So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS | ||
| # with our wrapper. All other code paths relying on the original core attn_implementation | ||
| # will still be executed — we only intercept at the point of calling attention. | ||
| # This is what we called "Being John Malkovich". | ||
| ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper | ||
|
|
||
| return mpu | ||
|
|
||
|
|
@@ -574,6 +643,18 @@ def refill(self): | |
| micro_batches = defaultdict(dict) | ||
| # XXX: replace with more efficient all-to-all? | ||
|
|
||
| # position_ids must exist before sharding so that after all_gather in | ||
| # UlyssesSPAttentionHF.forward() they reconstruct to correct global positions. | ||
| # Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER | ||
| # sharding, which after all_gather looks like packed sequences and breaks | ||
| # sdpa/flex_attention causal masking. | ||
| if "position_ids" not in batch: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about this. This might lead to a user getting the wrong behavior if they packed samples but forgot to supply pos ids. Should we simply assert if pos ids aren't there and not potentially create invalid pos ids? I agree there needs to be a check and it's not there.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, It would need to be in the TRL trainer, for the collator to always provide position_ids when SP is enabled, so the adapter never needs to generate them. I Can try to fix it there.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, Kashif. And probably then add an assert on SP side if pos id isn't there? |
||
| raise ValueError("Ulysses SP requires `position_ids` in every dataloader batch so that " | ||
| "each token retains its correct global position after sequence sharding. " | ||
| "For non-packed sequences: position_ids = torch.arange(seq_len) per sample. " | ||
| "For packed sequences: position_ids must reset at document boundaries. " | ||
| "Ensure your data collator includes position_ids in its output.") | ||
|
|
||
| # we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank | ||
| seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device) | ||
| seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -301,3 +301,87 @@ def test_disable_in_eval(self): | |
| # Verify: with disable_in_eval=True, full sequence input should produce | ||
| # the same output as baseline (SP is bypassed) | ||
| torch_assert_equal(logits_baseline, logits_sp) | ||
|
|
||
|
|
||
| class TestUlyssesSPHFHubKernel(DistributedTest): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since your PR adds support for flex attention would it be difficult to add a test exercising this path? |
||
| world_size = 2 | ||
|
|
||
| def test_register_hub_kernel_attn(self, monkeypatch): | ||
| """Test hub-kernel attention strings are registered before validation. | ||
| This verifies that DeepSpeed can accept kernel-based attention implementations | ||
| by triggering transformers' lazy registration path prior to checking | ||
| ALL_ATTENTION_FUNCTIONS. | ||
| """ | ||
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS | ||
|
|
||
| model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' | ||
| seq_length = 64 | ||
| sequence_parallel_size = self.world_size | ||
| micro_batch_size = 1 | ||
| hub_attn_implementation = 'kernels-community/flash-attn2' | ||
|
|
||
| called_with = [] | ||
| had_hub_key_before = hub_attn_implementation in ALL_ATTENTION_FUNCTIONS | ||
| original_sdpa = ALL_ATTENTION_FUNCTIONS['sdpa'] | ||
|
|
||
| def _mock_lazy_import_flash_attention(implementation, attention_wrapper=None, allow_all_kernels=False): | ||
| called_with.append(implementation) | ||
| if implementation == hub_attn_implementation and implementation not in ALL_ATTENTION_FUNCTIONS: | ||
| # Mimic transformers hub-kernel registration behavior. | ||
| ALL_ATTENTION_FUNCTIONS.register(implementation, ALL_ATTENTION_FUNCTIONS['sdpa']) | ||
| return (None, None, None, None), None | ||
|
|
||
| monkeypatch.setattr( | ||
| 'transformers.modeling_flash_attention_utils.lazy_import_flash_attention', | ||
| _mock_lazy_import_flash_attention, | ||
| ) | ||
|
|
||
| try: | ||
| mpu = UlyssesSPAttentionHF.register_with_transformers( | ||
| model_name_or_path=model_name_or_path, | ||
| core_attn_implementation=hub_attn_implementation, | ||
| sequence_parallel_size=sequence_parallel_size, | ||
| micro_batch_size=micro_batch_size, | ||
| seq_length=seq_length, | ||
| seq_length_is_variable=True, | ||
| ) | ||
| assert ALL_ATTENTION_FUNCTIONS['sdpa'] is original_sdpa | ||
| assert ALL_ATTENTION_FUNCTIONS[hub_attn_implementation] is not original_sdpa | ||
| finally: | ||
| if not had_hub_key_before and hub_attn_implementation in ALL_ATTENTION_FUNCTIONS: | ||
| ALL_ATTENTION_FUNCTIONS.pop(hub_attn_implementation, None) | ||
|
|
||
| assert mpu is not None | ||
| assert called_with == [hub_attn_implementation] | ||
|
|
||
|
|
||
| class TestUlyssesSPHFAttnImplMismatch(DistributedTest): | ||
| world_size = 2 | ||
|
|
||
| def test_register_with_mismatched_attn_impl_raises(self): | ||
| from transformers import AutoConfig | ||
|
|
||
| model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' | ||
| seq_length = 64 | ||
| sequence_parallel_size = self.world_size | ||
| micro_batch_size = 1 | ||
|
|
||
| hf_config = AutoConfig.from_pretrained(model_name_or_path) | ||
| hf_config._attn_implementation = "sdpa" | ||
|
|
||
| class MockModel: | ||
| """Mock model wrapper exposing a transformers config attribute.""" | ||
|
|
||
| def __init__(self, config): | ||
| self.config = config | ||
|
|
||
| with pytest.raises(ValueError, match='does not match model config attn_implementation'): | ||
| UlyssesSPAttentionHF.register_with_transformers( | ||
| model_name_or_path=MockModel(hf_config), | ||
| core_attn_implementation='flash_attention_2', | ||
| sequence_parallel_size=sequence_parallel_size, | ||
| micro_batch_size=micro_batch_size, | ||
| seq_length=seq_length, | ||
| seq_length_is_variable=True, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we future proof this for fa and say any official flash attention version?