Skip to content

Commit c6aa3de

Browse files
tohtanaspikerheado1234
authored and committed
AutoSP: fix torch 2.9 fake propagation issues (#2)
* Fix AutoSP shape propagation fake mode reuse * Fix AutoSP torch 2.9 fake propagation * Fix AutoSP shard slice ordering * Add comments for AutoSP torch 2.9 fixes --------- Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent 6f73ea2 commit c6aa3de

File tree

4 files changed

+115
-8
lines changed

4 files changed

+115
-8
lines changed

deepspeed/compile/custom_ops/all_to_all.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
import deepspeed.comm as dist
8+
from torch.utils._sympy.functions import FloorDiv
89
from .sp_dp_registry import get_group, is_setup, sp_size
910

1011

@@ -54,9 +55,25 @@ def all_to_all(
5455

5556
@torch.library.register_fake("autosp::all_to_all")
5657
def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
58+
59+
def maybe_restore_sharded_dim(dim: torch.SymInt, factor: int):
60+
# Torch 2.9 may keep `P * (s // P)` distinct from the original `s` during
61+
# fake shape propagation. When the local dim is exactly `FloorDiv(s, P)`,
62+
# restore the original symbol so downstream ops see a consistent sequence dim.
63+
node = getattr(dim, "node", None)
64+
if node is None:
65+
return dim * factor
66+
67+
expr = node.expr
68+
if isinstance(expr, FloorDiv) and expr.args[1] == factor:
69+
hint = node.hint * factor if node.has_hint() else None
70+
return node.shape_env.create_symintnode(expr.args[0], hint=hint)
71+
72+
return dim * factor
73+
5774
B, dim1, dim2, H = input.shape
5875
if scatter_idx == 1:
59-
return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H)
76+
return input.new_empty(B, dim1 // sp_size(), maybe_restore_sharded_dim(dim2, sp_size()), H)
6077
else:
6178
return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H)
6279

deepspeed/compile/passes/sp_compile.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010
import deepspeed.comm as dist
11-
from torch._subclasses.fake_tensor import FakeTensorMode
11+
from torch._subclasses.fake_tensor import FakeTensorMode, maybe_get_fake_mode
1212
from torch.fx import GraphModule, Node
1313
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
1414
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -80,7 +80,7 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs):
8080
seq_symint = val.shape[1]
8181
assert isinstance(
8282
seq_symint,
83-
torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`"
83+
torch.SymInt), f"expected sequence dimension to be of type {torch.SymInt!r} but found {type(seq_symint)!r}"
8484

8585
sym_seq_dim_node = find_node_by_name(gm, str(seq_symint))
8686
if sym_seq_dim_node is None:
@@ -184,15 +184,52 @@ def pass_canonicalize(gm: GraphModule, real_inputs):
184184

185185

186186
def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs):
187-
shape_env = ShapeEnv()
188-
fake_mode = FakeTensorMode(shape_env=shape_env)
187+
fake_mode = None
188+
for node in gm.graph.nodes:
189+
# Reuse the graph's existing fake mode when metadata is already present.
190+
# Its ShapeEnv owns the symbolic dims captured during tracing, so using a
191+
# fresh mode here can desynchronize fake inputs from graph metadata.
192+
if node.op == "placeholder" and "val" in node.meta:
193+
fake_val = node.meta["val"]
194+
if fake_val is not None and isinstance(fake_val, torch.Tensor):
195+
fake_mode = maybe_get_fake_mode(fake_val)
196+
elif fake_mode is None:
197+
fake_val = node.meta.get("example_value", node.meta.get("val"))
198+
if fake_val is not None and isinstance(fake_val, torch.Tensor):
199+
fake_mode = maybe_get_fake_mode(fake_val)
200+
if fake_mode is not None:
201+
break
202+
203+
if fake_mode is None:
204+
# Some graphs do not carry fake tensor metadata yet; create a fallback
205+
# mode so FakeTensorProp can still run shape-only execution.
206+
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
207+
189208
fake_inputs = []
190209
for t in real_inputs:
191210
if isinstance(t, torch.Tensor):
192211
fake_inputs.append(fake_mode.from_tensor(t))
193212
else:
194213
fake_inputs.append(t)
195-
FakeTensorProp(gm).propagate(*fake_inputs)
214+
215+
# Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path,
216+
# even though this pass only needs output metadata. Temporarily clear
217+
# attn_mask so shape propagation can proceed, then restore it immediately;
218+
# SDPA output shapes are still determined by Q/K/V shapes, not mask values.
219+
saved_sdpa_masks = []
220+
for attn_node in get_sdpa_nodes(gm):
221+
attn_mask = attn_node.kwargs.get("attn_mask")
222+
if attn_mask is not None:
223+
saved_sdpa_masks.append((attn_node, attn_mask))
224+
attn_node.update_kwarg("attn_mask", None)
225+
226+
try:
227+
# fake_inputs are already created under fake_mode above, so run
228+
# propagation without reconverting them into a different fake mode.
229+
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
230+
finally:
231+
for attn_node, attn_mask in saved_sdpa_masks:
232+
attn_node.update_kwarg("attn_mask", attn_mask)
196233

197234

198235
def apply_autosp(gm: GraphModule,

deepspeed/compile/util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,15 +591,21 @@ def shard_tensor_node(gm: GraphModule, tensor_node: Node):
591591
seq_len = val.shape[1]
592592

593593
assert isinstance(
594-
seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`"
594+
seq_len,
595+
torch.SymInt), (f"Expected sequence dimension to be {torch.SymInt!r} but instead found {type(seq_len)!r}")
595596

596597
symb_seq_int_node = find_node_by_name(gm, str(seq_len))
597598
assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}"
598599

599600
slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node)
600601
indices = (slice_all, slice_range)
601602

602-
with gm.graph.inserting_after(tensor_node):
603+
positions = {node: i for i, node in enumerate(gm.graph.nodes)}
604+
# Insert after the later dependency so the new getitem does not appear
605+
# before the symbolic slice nodes in graph order. Torch 2.9 bf16 can place
606+
# the SymInt placeholder after the tensor placeholder.
607+
anchor_node = slice_range if positions[slice_range] > positions[tensor_node] else tensor_node
608+
with gm.graph.inserting_after(anchor_node):
603609
sliced_node = gm.graph.call_function(
604610
operator.getitem,
605611
args=(tensor_node, indices),

tests/unit/v1/compile/test_compile_autosp.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
import torch
1111
import torch.nn.functional as F
12+
from torch.fx import Graph, GraphModule
1213

1314
from deepspeed.utils.torch import required_torch_version
1415
from deepspeed.accelerator import get_accelerator
@@ -235,3 +236,49 @@ def test(self, seq_len):
235236
f"User '{user.name}' still references the unsharded input_ids_node"
236237
assert sliced_node in user.all_input_nodes, \
237238
f"User '{user.name}' does not reference the sliced node"
239+
240+
def test_preserves_topological_order_when_sym_placeholder_follows_input(self):
241+
import deepspeed.comm as _dist
242+
from deepspeed.compile.custom_ops import sp_dp_registry as _registry
243+
from deepspeed.compile.fx import find_node_by_name, get_node_shape_meta
244+
from deepspeed.compile.util import shard_tensor_node, get_input_id_node
245+
246+
# Regression test for the torch 2.9 bf16 trace where the SymInt
247+
# placeholder can appear after input_ids. shard_tensor_node must still
248+
# produce a lint-clean graph instead of inserting getitem before its
249+
# symbolic slice dependencies.
250+
gm, _ = create_gm_nodes(seq_len=64)
251+
input_ids_node = get_input_id_node(gm)
252+
seq_symint = get_node_shape_meta(input_ids_node).shape[1]
253+
sym_seq_node = find_node_by_name(gm, str(seq_symint))
254+
assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph"
255+
256+
nodes = list(gm.graph.nodes)
257+
input_idx = nodes.index(input_ids_node)
258+
sym_idx = nodes.index(sym_seq_node)
259+
assert sym_idx < input_idx, "Expected source graph to place the symbolic placeholder before input_ids"
260+
261+
# Reorder placeholders to mirror the torch 2.9 bf16 trace where the symbolic
262+
# sequence placeholder can appear after input_ids.
263+
reordered_nodes = nodes[:]
264+
reordered_nodes.pop(input_idx)
265+
reordered_nodes.insert(sym_idx, input_ids_node)
266+
reordered_nodes.pop(sym_idx + 1)
267+
reordered_nodes.insert(input_idx, sym_seq_node)
268+
269+
reordered_graph = Graph()
270+
env = {}
271+
for node in reordered_nodes:
272+
new_node = reordered_graph.node_copy(node, lambda n: env[n])
273+
new_node.meta = node.meta.copy()
274+
env[node] = new_node
275+
reordered_graph.lint()
276+
277+
reordered_gm = GraphModule(gm, reordered_graph)
278+
reordered_input_ids = get_input_id_node(reordered_gm)
279+
280+
with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \
281+
patch.object(_dist, 'get_rank', return_value=0):
282+
shard_tensor_node(reordered_gm, reordered_input_ids)
283+
284+
reordered_gm.graph.lint()

0 commit comments

Comments (0)