@@ -5,6 +5,7 @@

 import torch
 import torch.distributed as dist
+import torch._inductor.config as inductor_config

 from tqdm import tqdm

@@ -91,14 +92,22 @@ def _patch_model_forward_for_cuda_graph_capture(self, num_tokens: int):

         original_forward = self.model.forward
         mode = str(getattr(self.config, "torch_compile_mode", "reduce-overhead") or "reduce-overhead")
+        compile_config_patch = {
+            # We already wrap the compiled forward in our own CUDA graph.
+            # Inductor's internal cudagraph trees try to replay during outer
+            # capture and fail with "Cannot prepare for replay during capturing".
+            "triton.cudagraphs": False,
+            "triton.cudagraph_trees": False,
+        }
         try:
-            self.model.forward = torch.compile(
-                torch.no_grad()(original_forward),
-                mode=mode,
-                fullgraph=False,
-                dynamic=False,
-            )
-            yield True
+            with inductor_config.patch(compile_config_patch):
+                self.model.forward = torch.compile(
+                    torch.no_grad()(original_forward),
+                    mode=mode,
+                    fullgraph=False,
+                    dynamic=False,
+                )
+                yield True
         finally:
             self.model.forward = original_forward

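For reference, the pattern this hunk sets up can be exercised standalone: compile a forward with Inductor's cudagraphs disabled via a config patch, warm it up eagerly, then capture the compiled callable with an outer torch.cuda.CUDAGraph. The sketch below is illustrative only; the toy forward, tensor shapes, and warm-up count are assumptions, not code from this repo.

import torch
import torch._inductor.config as inductor_config

def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x)

x = torch.randn(64, 64, device="cuda")   # static input buffer
out = torch.empty_like(x)                # static output buffer

# Disable Inductor's internal cudagraphs so the compiled callable does not
# try to replay its own graph while the outer capture is in progress.
with inductor_config.patch({"triton.cudagraphs": False, "triton.cudagraph_trees": False}):
    compiled = torch.compile(forward, mode="reduce-overhead", fullgraph=False, dynamic=False)

    # Warm up outside capture: compilation and autotuning must not be recorded.
    for _ in range(3):
        out.copy_(compiled(x))
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):        # capture runs on a side stream
        out.copy_(compiled(x))

x.copy_(torch.ones(64, 64, device="cuda"))
graph.replay()                           # out now holds relu(x @ x) for the new x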
@@ -134,9 +143,14 @@ def _capture_model_forward_graph(
         num_tokens: int,
         *,
         allow_compile: bool = False,
+        capture_logits: bool = False,
     ) -> torch.cuda.CUDAGraph:
         def run_once() -> None:
-            outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
+            hidden_states = self.model(input_ids[:num_tokens], positions[:num_tokens])
+            if capture_logits:
+                outputs[:num_tokens] = self.model.compute_logits(hidden_states)
+            else:
+                outputs[:num_tokens] = hidden_states

         stream = self._get_graph_capture_stream()
         pool = self._get_graph_pool()
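Folding compute_logits into run_once means a replay writes logits, rather than hidden states, into the shared static buffer, saving a separate projection launch per step. Below is a runnable toy version of that static-buffer capture/replay pattern; backbone and lm_head are stand-ins for self.model and self.model.compute_logits, and all sizes are made up.

import torch

device = "cuda"
hidden, vocab, max_tokens = 16, 32, 8
backbone = torch.nn.Linear(hidden, hidden, device=device)  # stand-in for self.model
lm_head = torch.nn.Linear(hidden, vocab, device=device)    # stand-in for compute_logits

capture_logits = True
inputs = torch.zeros(max_tokens, hidden, device=device)
outputs = torch.zeros(max_tokens, vocab if capture_logits else hidden, device=device)

def run_once(num_tokens: int) -> None:
    hidden_states = backbone(inputs[:num_tokens])
    if capture_logits:
        outputs[:num_tokens] = lm_head(hidden_states)  # projection recorded in-graph
    else:
        outputs[:num_tokens] = hidden_states

with torch.no_grad():
    run_once(max_tokens)                     # warm-up before capture
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_once(max_tokens)

    inputs.copy_(torch.randn_like(inputs))   # refresh the static input in place
    graph.replay()                           # outputs[:max_tokens] now holds logits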
@@ -291,6 +305,25 @@ def _model_hidden_dtype(self) -> torch.dtype:
         except StopIteration:
             return torch.get_default_dtype()

+    def _model_logits_dtype(self) -> torch.dtype:
+        return self._model_hidden_dtype()
+
+    def _model_logits_size(self) -> int:
+        lm_head = getattr(self.model, "lm_head", None)
+        partition_size = getattr(lm_head, "num_embeddings_per_partition", None)
+        if partition_size is not None:
+            return int(partition_size)
+        vocab_size = getattr(self.config, "tokenizer_vocab_size", None) or getattr(self.config.hf_config, "vocab_size")
+        if self.world_size <= 1:
+            return int(vocab_size)
+        return int(vocab_size) // int(self.world_size)
+
+    def _can_capture_logits_in_graph(self) -> bool:
+        # Capturing a tensor-parallel lm_head would record all-gather/NCCL and
+        # rank-local None outputs, and logits buffers are vocab-sized, so keep
+        # this first version to the single-request SDAR eval path for now.
+        return self.world_size == 1 and int(self.config.max_num_reqs) == 1
+
     def _ensure_runtime_static_buffers(
         self,
         *,
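The fallback branch of _model_logits_size is plain vocab-sharding arithmetic: when the lm_head exposes no num_embeddings_per_partition, assume it is split evenly across tensor-parallel ranks, so the rank-local logits buffer is vocab_size // world_size wide. A self-contained restatement (the helper name and the 128256 vocab size are illustrative, not from this repo):

def logits_size_per_rank(vocab_size: int, world_size: int) -> int:
    # Mirrors the fallback in _model_logits_size: without a
    # num_embeddings_per_partition attribute to read, assume the lm_head
    # is sharded evenly over the tensor-parallel ranks.
    if world_size <= 1:
        return vocab_size
    return vocab_size // world_size

assert logits_size_per_rank(128256, 1) == 128256  # single GPU: full vocab
assert logits_size_per_rank(128256, 4) == 32064   # TP=4: one shard per rank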
@@ -765,6 +798,9 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):

     max_num_tokens = max_num_seqs * chunk_size
     device = self._cuda_graph_device()
+    capture_logits = self._can_capture_logits_in_graph()
+    self.graph_outputs_are_logits = capture_logits
+    output_size = self._model_logits_size() if capture_logits else hf_config.hidden_size

     input_ids = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
     positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
@@ -775,7 +811,7 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
     status_table = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
     prefix_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
     padded_prefix_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
-    outputs = torch.zeros(max_num_tokens, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
+    outputs = torch.zeros(max_num_tokens, output_size, dtype=self._model_logits_dtype(), device=device)

     cu_seqlens_q = torch.zeros(max_num_seqs + 1, dtype=torch.int32, device=device)
     for i in range(max_num_seqs + 1):
@@ -819,7 +855,14 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
            padded_prefix_lens=padded_prefix_lens[:num_seqs],
        )

-        graph = self._capture_model_forward_graph(input_ids, positions, outputs, num_tokens, allow_compile=True)
+        graph = self._capture_model_forward_graph(
+            input_ids,
+            positions,
+            outputs,
+            num_tokens,
+            allow_compile=True,
+            capture_logits=capture_logits,
+        )
         if self.graph_pool is None:
             self.graph_pool = graph.pool()
         self.graphs[num_tokens] = graph
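The tail of this hunk seeds self.graph_pool from the first captured graph so later captures record into the same allocator pool instead of each reserving their own. A minimal sketch of that sharing, with a toy kernel and batch sizes of my choosing; capturing the largest size first is the usual way to make one pool serve every graph, though the iteration order of this file's loop is not visible in the hunk.

import torch

x = torch.zeros(8, device="cuda")
y = torch.zeros(8, device="cuda")

def step(n: int) -> None:
    y[:n] = torch.sin(x[:n]) * 2.0

step(8)                      # warm-up before any capture
torch.cuda.synchronize()

graphs = {}
pool = None
for n in (8, 4, 2, 1):       # largest first so the shared pool fits every graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        step(n)
    if pool is None:
        pool = g.pool()      # reuse the first graph's memory pool from here on
    graphs[n] = g

x[:4] = torch.arange(4.0, device="cuda")
graphs[4].replay()           # y[:4] == sin(x[:4]) * 2.0, no new kernel launches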