Skip to content
137 changes: 109 additions & 28 deletions deepspeed/runtime/sequence_parallel/ulysses_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import deepspeed.comm as dist
import importlib.metadata
import math
import re
import torch
import torch.distributed.nn

Expand Down Expand Up @@ -121,6 +122,12 @@ def __init__(
self.skip_all_but_last_attention_debug_mode = False
self.rotating_layer_counter = 0 # used for dev work

self.core_attn_implementation = None # set by register_with_transformers
self._flex_block_mask_cls = None # set by register_with_transformers
self._flex_create_block_mask = None # set by register_with_transformers
self._flex_block_mask_cached = None # cached BlockMask for flex_attention
self._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation

self.local_q_head_count = attn_head_count // self.world_size

# if we have 4 kv heads and sp 8, we need to replicate kv heads 2x
Expand Down Expand Up @@ -272,19 +279,11 @@ def forward(
key = rearrange(key, "bs hc sl hs -> sl bs hc hs") # .contiguous()
value = rearrange(value, "bs hc sl hs -> sl bs hc hs") # .contiguous()

# core attn like FA2 expects an unsharded `position_ids` - without which packed samples
# will return loss=nan.
#
# XXX: need to figure out if we can do the same for SDPA - as it doesn't require this and
# wants an attention mask, so possibly doing this for FA2 only?
#
# Ideally we would passing the original unsharded position_ids - but we have no way to pass
# it here as HF Transformers drops unexpected keys in `batch` - so either we need to stash
# it somewhere in UlyssesSPDataLoaderAdapter and retrieve it here or we could gather it once
# per batch and stash it inside `module` arg - I already have a machinery to figure out
# which layer number is being called below in the skip_all_but_last_attention_debug_mode
# code where rotating_layer_counter is used - so we could calculate it on the first layer
# and re-use on the remaining layers
# All attention backends need unsharded position_ids after the all-to-all.
# FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention
# need them to be monotonically increasing so causal masking works correctly.
# UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding,
# so after gathering here they reconstruct to the correct global positions.
if "position_ids" in kwargs:
position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)]
dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group)
Expand All @@ -311,6 +310,36 @@ def forward(
if self.kv_replication_factor > 1:
module.num_key_value_groups = query_layer.size(-3) // key_layer.size(-3)

# For flex_attention: the wrapper preserved the BlockMask from the model, but it
# was built for the local shard's sequence length. Rebuild it for the full gathered
# sequence length after the all-to-all.
# XXX: currently hardcodes a causal mask_mod — models with sliding window or other
# non-standard patterns would need the mask_mod extracted from the original BlockMask.
if self._flex_block_mask_cls is not None and isinstance(attention_mask, self._flex_block_mask_cls):
seq_len = query_layer.shape[2]
batch_size = query_layer.shape[0]
cache_key = (batch_size, seq_len)

# Cache the BlockMask — create_block_mask is expensive and the mask is the
# same for all layers within a forward pass. Only rebuild when dimensions change.
if self._flex_block_mask_cache_key != cache_key:

def causal_mask(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx

self._flex_block_mask_cached = self._flex_create_block_mask(
mask_mod=causal_mask,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
device=query_layer.device,
_compile=True,
)
self._flex_block_mask_cache_key = cache_key

attention_mask = self._flex_block_mask_cached

if not self.skip_all_but_last_attention_debug_mode:
# expects: [bs hc_l sl hs]
context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, *args,
Expand Down Expand Up @@ -411,15 +440,34 @@ def register_with_transformers(
# if we don't have the model yet at this stage
hf_model_config = AutoConfig.from_pretrained(model_name_or_path)

supported_attn_implementation = ["flash_attention_2", "flash_attention_3", "sdpa"]
if core_attn_implementation not in supported_attn_implementation:
# notes on the excluded ones:
# - eager: The problem is that `eager` wants an attention_mask and it creates the wrong attention mask it seems if we don't provide one - it's possible that we could somehow solve this, but it's also unlikely someone will want to use the slow eager attention with sequence parallelism
# - flex_attention: haven't tried

model_attn_implementation = getattr(hf_model_config, "_attn_implementation", None)
if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation:
raise ValueError(
f"core_attn_implementation='{core_attn_implementation}' does not match "
f"model config attn_implementation='{model_attn_implementation}'. "
"Set both to the same value so sequence-parallel wrapper can intercept the active attention path.")

# eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back
# to is_causal=True like sdpa — so it's incompatible with SP which discards masks.
unsupported_attn_implementation = ["eager", "paged|eager"]
if core_attn_implementation in unsupported_attn_implementation:
raise ValueError(
f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence"
f" parallelism. Set core_attn_implementation arg to one of {supported_attn_implementation}.")
f" parallelism because it requires a 4D attention_mask (O(n²) memory)."
f" Use 'flash_attention_2', 'flash_attention_3', 'flex_attention', 'sdpa',"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we future-proof this for FA and accept any official flash attention version?

f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2').")

# Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers.
# Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS.
is_hub_kernel_attn = (isinstance(core_attn_implementation, str) and re.search(
r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", core_attn_implementation) is not None)
if is_hub_kernel_attn:
try:
from transformers.modeling_flash_attention_utils import lazy_import_flash_attention
except ImportError as e:
raise ImportError("Hub kernel attention requires a transformers version exposing "
"`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`.") from e
lazy_import_flash_attention(core_attn_implementation)

if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
raise ValueError(
Expand Down Expand Up @@ -448,6 +496,16 @@ def register_with_transformers(
global_seq_length=global_seq_length,
disable_in_eval=disable_in_eval,
)
uattn.core_attn_implementation = core_attn_implementation

# Import flex_attention utilities once; stored on the instance for use in
# both the wrapper (to detect BlockMask) and forward() (to rebuild it).
uattn._flex_block_mask_cls = None
uattn._flex_create_block_mask = None
if core_attn_implementation == "flex_attention":
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
uattn._flex_block_mask_cls = BlockMask
uattn._flex_create_block_mask = create_block_mask

def uattn_wrapper(
module: torch.nn.Module,
Expand All @@ -459,27 +517,38 @@ def uattn_wrapper(
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:

# We are relaying on position_ids for SP to work so attention_mask has to be None
# the problem is that HF currently doesn't know anything about ALL_ATTENTION_FUNCTIONS["ulysses"] so it doesn't make a special case like for "flash_attention_2" and "sdpa" and it creates an attention mask on the fly and it breaks things.
attention_mask = None
# SP relies on position_ids (not attention_mask) for causal masking.
# HF doesn't know about the SP wrapper, so it creates an attention_mask for
# the local shard's sequence length — which is invalid after the SP all-to-all
# gathers the full sequence. A 4D mask at full sequence length would also be
# O(n²) memory. So we discard 4D tensor masks.
#
# Keep BlockMask (flex_attention) — it's a compressed sparse representation.
# It will be rebuilt for the full gathered sequence in forward().
if uattn._flex_block_mask_cls is not None and isinstance(attention_mask, uattn._flex_block_mask_cls):
pass # keep BlockMask — will be rebuilt in forward() for gathered seq len
else:
attention_mask = None

attn_output, attn_weights = uattn(
module,
query,
key,
value,
attention_mask,
# XXX: fixme
*args,
**kwargs,
)
return attn_output, attn_weights

# We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper)
# The problem with this approach is that we are missing on all the special use cases in HF Transformers that do things like: if self.config._attn_implementation == "flash_attention_2": ...
# So instead we hack `ALL_ATTENTION_FUNCTIONS` to override all existing keys with our implementation, since it only gets used at the point of calling the attention and that's what we want, all other code branches relying on the original core `attn_implementation` will still be executed. This is what we called "Being John Malkovich"
for key in ALL_ATTENTION_FUNCTIONS.keys():
ALL_ATTENTION_FUNCTIONS[key] = uattn_wrapper
# The problem with that approach is that we'd miss all the special-case branches in
# HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...`
# So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS
# with our wrapper. All other code paths relying on the original core attn_implementation
# will still be executed — we only intercept at the point of calling attention.
# This is what we called "Being John Malkovich".
ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper

return mpu

Expand Down Expand Up @@ -574,6 +643,18 @@ def refill(self):
micro_batches = defaultdict(dict)
# XXX: replace with more efficient all-to-all?

# position_ids must exist before sharding so that after all_gather in
# UlyssesSPAttentionHF.forward() they reconstruct to correct global positions.
# Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER
# sharding, which after all_gather looks like packed sequences and breaks
# sdpa/flex_attention causal masking.
if "position_ids" not in batch:
Copy link
Collaborator

@stas00 stas00 Mar 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this. This might lead to a user getting the wrong behavior if they packed samples but forgot to supply pos ids. Should we simply assert if pos ids aren't there and not potentially create invalid pos ids?

I agree there needs to be a check and it's not there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would need to be in the TRL trainer — the collator should always provide position_ids when SP is enabled, so the adapter never needs to generate them. I can try to fix it there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Kashif.

And probably then add an assert on SP side if pos id isn't there?

raise ValueError("Ulysses SP requires `position_ids` in every dataloader batch so that "
"each token retains its correct global position after sequence sharding. "
"For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
"For packed sequences: position_ids must reset at document boundaries. "
"Ensure your data collator includes position_ids in its output.")

# we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank
seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device)
seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)]
Expand Down
84 changes: 84 additions & 0 deletions tests/unit/ulysses_alst/test_ulysses_sp_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,87 @@ def test_disable_in_eval(self):
# Verify: with disable_in_eval=True, full sequence input should produce
# the same output as baseline (SP is bypassed)
torch_assert_equal(logits_baseline, logits_sp)


class TestUlyssesSPHFHubKernel(DistributedTest):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since your PR adds support for flex attention would it be difficult to add a test exercising this path?

world_size = 2

def test_register_hub_kernel_attn(self, monkeypatch):
    """Test hub-kernel attention strings are registered before validation.

    This verifies that DeepSpeed can accept kernel-based attention implementations
    by triggering transformers' lazy registration path prior to checking
    ALL_ATTENTION_FUNCTIONS.
    """
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
    seq_length = 64
    sequence_parallel_size = self.world_size
    micro_batch_size = 1
    # an "org/repo" style string selects a hub-hosted kernel rather than a built-in backend
    hub_attn_implementation = 'kernels-community/flash-attn2'

    called_with = []  # records every implementation string the mocked lazy importer receives
    # snapshot registry state so the finally-block can restore it without clobbering
    # a pre-existing registration from another test
    had_hub_key_before = hub_attn_implementation in ALL_ATTENTION_FUNCTIONS
    original_sdpa = ALL_ATTENTION_FUNCTIONS['sdpa']

    def _mock_lazy_import_flash_attention(implementation, attention_wrapper=None, allow_all_kernels=False):
        # Stand-in for transformers' lazy hub-kernel importer: record the call and
        # register the hub key exactly as the real importer would.
        called_with.append(implementation)
        if implementation == hub_attn_implementation and implementation not in ALL_ATTENTION_FUNCTIONS:
            # Mimic transformers hub-kernel registration behavior.
            ALL_ATTENTION_FUNCTIONS.register(implementation, ALL_ATTENTION_FUNCTIONS['sdpa'])
        return (None, None, None, None), None

    monkeypatch.setattr(
        'transformers.modeling_flash_attention_utils.lazy_import_flash_attention',
        _mock_lazy_import_flash_attention,
    )

    try:
        mpu = UlyssesSPAttentionHF.register_with_transformers(
            model_name_or_path=model_name_or_path,
            core_attn_implementation=hub_attn_implementation,
            sequence_parallel_size=sequence_parallel_size,
            micro_batch_size=micro_batch_size,
            seq_length=seq_length,
            seq_length_is_variable=True,
        )
        # the built-in sdpa entry must be left untouched by registration ...
        assert ALL_ATTENTION_FUNCTIONS['sdpa'] is original_sdpa
        # ... while the hub-kernel entry is overridden with the SP wrapper
        assert ALL_ATTENTION_FUNCTIONS[hub_attn_implementation] is not original_sdpa
    finally:
        # remove the hub key only if this test introduced it, to avoid leaking
        # registry state across tests
        if not had_hub_key_before and hub_attn_implementation in ALL_ATTENTION_FUNCTIONS:
            ALL_ATTENTION_FUNCTIONS.pop(hub_attn_implementation, None)

    assert mpu is not None
    # the lazy importer must have been triggered exactly once, with the hub string
    assert called_with == [hub_attn_implementation]


class TestUlyssesSPHFAttnImplMismatch(DistributedTest):
    world_size = 2

    def test_register_with_mismatched_attn_impl_raises(self):
        """Registering with a core_attn_implementation that disagrees with the
        model config's attn_implementation must raise a ValueError."""
        from transformers import AutoConfig

        class _ModelStub:
            """Minimal stand-in for a model: exposes only a `.config` attribute."""

            def __init__(self, config):
                self.config = config

        tiny_model = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
        config = AutoConfig.from_pretrained(tiny_model)
        # force the model side to sdpa while requesting flash_attention_2 below,
        # which should trip the mismatch check
        config._attn_implementation = "sdpa"

        with pytest.raises(ValueError, match='does not match model config attn_implementation'):
            UlyssesSPAttentionHF.register_with_transformers(
                model_name_or_path=_ModelStub(config),
                core_attn_implementation='flash_attention_2',
                sequence_parallel_size=self.world_size,
                micro_batch_size=1,
                seq_length=64,
                seq_length_is_variable=True,
            )
Loading