
Commit d0c83e1

feat: refactor attention implementation and enhance model runner capabilities
- Updated `triton_attention` to fetch `attn_metadata` internally, simplifying function parameters.
- Enhanced `FullStaticRunner` to conditionally return logits based on the new `graph_outputs_are_logits` attribute.
- Improved `MultiBlockModelRunnerTemplate` to support capturing logits during graph execution and adjusted output tensor sizing accordingly.
- Added methods to determine logits data type and size, ensuring compatibility with model configurations.
1 parent 0a0ba03 · commit d0c83e1

3 files changed

Lines changed: 60 additions & 15 deletions

diffulex/attention/attn_impl.py

Lines changed: 5 additions & 5 deletions
@@ -51,11 +51,13 @@ def triton_attention(
     v: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
-    attn_metadata: AttnMetaDataBase,
-    is_unified_layout: bool,
 ) -> torch.Tensor:
     # Keep Triton JIT/autotune state out of torch.compile; CUDA graph capture
     # still records the launched kernels.
+    from diffulex.attention import fetch_attn_metadata
+
+    attn_metadata: AttnMetaDataBase = fetch_attn_metadata()
+    is_unified_layout = attn_metadata.kv_cache_layout == "unified"
     if k_cache.numel() and v_cache.numel():
         if attn_metadata.need_kv_cache_store:
             store_kv_cache = store_kv_cache_unified_layout if is_unified_layout else store_kv_cache_distinct_layout

@@ -143,11 +145,9 @@ def forward(
         if self.attn_impl != "triton":
             raise ValueError(f"Unsupported attn_impl: {self.attn_impl}")

-        attn_metadata: AttnMetaDataBase = self.fetch_attn_metadata()
         k_cache, v_cache = self.k_cache, self.v_cache
-        is_unified_layout = attn_metadata.kv_cache_layout == "unified"

-        o = triton_attention(q, k, v, k_cache, v_cache, attn_metadata, is_unified_layout)
+        o = triton_attention(q, k, v, k_cache, v_cache)

         # Final reshape
         return rearrange(o, "s nh hd -> s (nh hd)").contiguous()
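
The net effect is that `triton_attention` now takes only tensors, which keeps its signature friendly to torch.compile tracing and CUDA graph capture. The commit does not show how `fetch_attn_metadata` is implemented; below is a minimal sketch of the usual pattern, assuming a module-level slot that the runner binds before each forward pass (`_CURRENT_ATTN_METADATA` and `set_attn_metadata` are hypothetical names, not shown in this commit):

from typing import Optional

# Hypothetical module-level slot; the real diffulex.attention helper may differ.
_CURRENT_ATTN_METADATA: Optional["AttnMetaDataBase"] = None

def set_attn_metadata(metadata: "AttnMetaDataBase") -> None:
    # The model runner binds metadata here before launching a forward pass.
    global _CURRENT_ATTN_METADATA
    _CURRENT_ATTN_METADATA = metadata

def fetch_attn_metadata() -> "AttnMetaDataBase":
    # Attention layers read it back, so metadata never rides through the
    # traced call signature where torch.compile would specialize on it.
    assert _CURRENT_ATTN_METADATA is not None, "attention metadata not bound"
    return _CURRENT_ATTN_METADATA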

diffulex/strategy_template/multi_block/engine/full_static_runner.py

Lines changed: 2 additions & 0 deletions
@@ -88,4 +88,6 @@ def run_decode(
         self.owner._bind_decode_graph_extra_metadata(attn_metadata, graph_vars, num_tokens)

         graph.replay()
+        if bool(getattr(self.owner, "graph_outputs_are_logits", False)):
+            return graph_vars["outputs"][:num_tokens]
         return self.owner.model.compute_logits(graph_vars["outputs"][:num_tokens])
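
With this guard, `run_decode` returns the replayed buffer directly when the captured graph already produced logits, and only falls back to an eager `compute_logits` call otherwise. A self-contained sketch of the same branching, with illustrative shapes and names local to this example:

import torch

def decode_output(outputs: torch.Tensor, num_tokens: int,
                  graph_outputs_are_logits: bool, compute_logits):
    # Mirrors the new early return: skip the eager lm_head projection
    # when the captured graph already wrote logits into `outputs`.
    if graph_outputs_are_logits:
        return outputs[:num_tokens]
    return compute_logits(outputs[:num_tokens])

lm_head = torch.nn.Linear(8, 16, bias=False)
hidden = torch.randn(4, 8)
logits_eager = decode_output(hidden, 4, False, lm_head)      # -> [4, 16]
pre_logits = torch.randn(4, 16)
logits_graph = decode_output(pre_logits, 4, True, lm_head)   # -> [4, 16]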

diffulex/strategy_template/multi_block/engine/model_runner.py

Lines changed: 53 additions & 10 deletions
@@ -5,6 +5,7 @@

 import torch
 import torch.distributed as dist
+import torch._inductor.config as inductor_config

 from tqdm import tqdm

@@ -91,14 +92,22 @@ def _patch_model_forward_for_cuda_graph_capture(self, num_tokens: int):

         original_forward = self.model.forward
         mode = str(getattr(self.config, "torch_compile_mode", "reduce-overhead") or "reduce-overhead")
+        compile_config_patch = {
+            # We already wrap the compiled forward in our own CUDA graph.
+            # Inductor's internal cudagraph trees try to replay during outer
+            # capture and fail with "Cannot prepare for replay during capturing".
+            "triton.cudagraphs": False,
+            "triton.cudagraph_trees": False,
+        }
         try:
-            self.model.forward = torch.compile(
-                torch.no_grad()(original_forward),
-                mode=mode,
-                fullgraph=False,
-                dynamic=False,
-            )
-            yield True
+            with inductor_config.patch(compile_config_patch):
+                self.model.forward = torch.compile(
+                    torch.no_grad()(original_forward),
+                    mode=mode,
+                    fullgraph=False,
+                    dynamic=False,
+                )
+                yield True
         finally:
             self.model.forward = original_forward
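
The patch matters because the compiled forward is later recorded into the runner's own torch.cuda.CUDAGraph; with mode="reduce-overhead", Inductor would otherwise manage its own cudagraph trees and fail with the error quoted in the comment above. A standalone sketch of the same pattern, using an illustrative model and shapes rather than the runner's real ones:

import torch
import torch._inductor.config as inductor_config

model = torch.nn.Linear(64, 64).cuda()
x = torch.zeros(8, 64, device="cuda")

# Keep Inductor's CUDA graph machinery off so the manual capture below
# records the compiled kernels instead of colliding with an inner replay.
with inductor_config.patch({"triton.cudagraphs": False, "triton.cudagraph_trees": False}):
    compiled = torch.compile(model, mode="reduce-overhead", fullgraph=False, dynamic=False)
    compiled(x)  # warm up: trigger compilation before capture

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = compiled(x)  # recorded into the outer, manually managed graph

x.copy_(torch.ones_like(x))
graph.replay()  # reruns the captured kernels against the static buffers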

@@ -134,9 +143,14 @@ def _capture_model_forward_graph(
         num_tokens: int,
         *,
         allow_compile: bool = False,
+        capture_logits: bool = False,
     ) -> torch.cuda.CUDAGraph:
         def run_once() -> None:
-            outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
+            hidden_states = self.model(input_ids[:num_tokens], positions[:num_tokens])
+            if capture_logits:
+                outputs[:num_tokens] = self.model.compute_logits(hidden_states)
+            else:
+                outputs[:num_tokens] = hidden_states

         stream = self._get_graph_capture_stream()
         pool = self._get_graph_pool()
@@ -291,6 +305,25 @@ def _model_hidden_dtype(self) -> torch.dtype:
         except StopIteration:
             return torch.get_default_dtype()

+    def _model_logits_dtype(self) -> torch.dtype:
+        return self._model_hidden_dtype()
+
+    def _model_logits_size(self) -> int:
+        lm_head = getattr(self.model, "lm_head", None)
+        partition_size = getattr(lm_head, "num_embeddings_per_partition", None)
+        if partition_size is not None:
+            return int(partition_size)
+        vocab_size = getattr(self.config, "tokenizer_vocab_size", None) or getattr(self.config.hf_config, "vocab_size")
+        if self.world_size <= 1:
+            return int(vocab_size)
+        return int(vocab_size) // int(self.world_size)
+
+    def _can_capture_logits_in_graph(self) -> bool:
+        # TP lm_head captures all-gather/NCCL and rank-local None outputs.
+        # Logits buffers are vocab-sized, so keep the first version to the
+        # single-request SDAR eval path before enabling broader serving shapes.
+        return self.world_size == 1 and int(self.config.max_num_reqs) == 1
+
     def _ensure_runtime_static_buffers(
         self,
         *,
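
`_model_logits_size` prefers the `lm_head`'s own partition size when the embedding is vocab-parallel, and only then falls back to dividing the vocabulary evenly across tensor-parallel ranks. A worked example of that precedence (the vocab size here is an assumed value, not from the commit):

def logits_size(vocab_size: int, world_size: int,
                partition_size: int | None = None) -> int:
    # Same precedence as the helper: an explicit partition size wins,
    # otherwise split the vocab evenly across tensor-parallel ranks.
    if partition_size is not None:
        return partition_size
    return vocab_size if world_size <= 1 else vocab_size // world_size

assert logits_size(128256, 1) == 128256        # single rank: full vocab
assert logits_size(128256, 4) == 32064         # TP=4: one shard per rank
assert logits_size(128256, 4, 32064) == 32064  # lm_head already knows its shard
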
@@ -765,6 +798,9 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):

     max_num_tokens = max_num_seqs * chunk_size
     device = self._cuda_graph_device()
+    capture_logits = self._can_capture_logits_in_graph()
+    self.graph_outputs_are_logits = capture_logits
+    output_size = self._model_logits_size() if capture_logits else hf_config.hidden_size

     input_ids = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
     positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
@@ -775,7 +811,7 @@
     status_table = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
     prefix_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
    padded_prefix_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
-    outputs = torch.zeros(max_num_tokens, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
+    outputs = torch.zeros(max_num_tokens, output_size, dtype=self._model_logits_dtype(), device=device)

     cu_seqlens_q = torch.zeros(max_num_seqs + 1, dtype=torch.int32, device=device)
     for i in range(max_num_seqs + 1):
@@ -819,7 +855,14 @@
             padded_prefix_lens=padded_prefix_lens[:num_seqs],
         )

-        graph = self._capture_model_forward_graph(input_ids, positions, outputs, num_tokens, allow_compile=True)
+        graph = self._capture_model_forward_graph(
+            input_ids,
+            positions,
+            outputs,
+            num_tokens,
+            allow_compile=True,
+            capture_logits=capture_logits,
+        )
         if self.graph_pool is None:
             self.graph_pool = graph.pool()
         self.graphs[num_tokens] = graph
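
The conservative `_can_capture_logits_in_graph` guard also keeps memory in check: capturing logits switches the static output buffer from hidden-sized to vocab-sized rows. A quick estimate with assumed dimensions (none of these values come from the commit):

# Assumed dimensions for illustration only.
hidden_size, vocab_size = 4096, 128256
max_num_tokens = 64
bytes_per_elem = 2  # fp16/bf16

hidden_buf = max_num_tokens * hidden_size * bytes_per_elem  # ~0.5 MiB
logits_buf = max_num_tokens * vocab_size * bytes_per_elem   # ~15.7 MiB
print(logits_buf / hidden_buf)  # ~31x more memory per captured token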
