Skip to content

Commit dc5bd51

Browse files
Add AutoSP to DeepSpeed
--------- Signed-off-by: Neel Dani <neeldani98@gmail.com> Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
1 parent 285cae3 commit dc5bd51

File tree

12 files changed

+765
-44
lines changed

12 files changed

+765
-44
lines changed

deepspeed/compile/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
# DeepSpeed Team
55

6+
from typing import List, Optional, Literal
67
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
78

9+
PassName = Literal["z1", "z3", "autosp"]
10+
811

912
class CompileConfig(DeepSpeedConfigModel):
1013
""" Configure compile settings """
@@ -53,3 +56,6 @@ class CompileConfig(DeepSpeedConfigModel):
5356

5457
keep_all_input_tensors: bool = False
5558
""" Keep real values for all input tensors in InputStorage instead of using dummy values """
59+
60+
passes: Optional[List[PassName]] = None
61+
""" Composes different optimizations. """

deepspeed/compile/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#########################################
# AUTOSP
#########################################
# Metadata keys identifying the model's input/label/position-id tensors.
# Presumably used as the 'tag' looked up by the AutoSP graph passes
# (see find_node_by_tag in deepspeed/compile/fx.py) — TODO confirm at call sites.
AUTOSP_INPUT_ID_KEY = "input_id"
AUTOSP_LABEL_ID_KEY = "label_id"
AUTOSP_POSITION_ID_KEY = "position_id"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
from .all_to_all import all_to_all
7+
from . import sp_dp_registry
8+
9+
__all__ = ["all_to_all", "sp_dp_registry"]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import torch
7+
import deepspeed.comm as dist
8+
from .sp_dp_registry import get_group, is_setup, sp_size
9+
10+
11+
@torch.library.custom_op("autosp::all_to_all", mutates_args=())
def all_to_all(
    input: torch.Tensor,
    scatter_idx: int,
    gather_idx: int,
    name: str,
) -> torch.Tensor:
    """
    All-to-all collective for SDPA tensors [B, N, S, H].

    For QKV (scatter_idx=1, gather_idx=2):
        [B, N, S/P, H] -> [B, N/P, S, H]
    For O (scatter_idx=2, gather_idx=1):
        [B, N/P, S, H] -> [B, N, S/P, H]
    """
    # NOTE(review): only scatter_idx selects the branch; gather_idx is implied
    # by scatter_idx here, and `name` is used only for autograd bookkeeping.
    assert is_setup(), 'Incorrect initialization of SP/DP mesh.'
    B, dim1, dim2, H = input.shape
    # Ranks appear to be laid out SP-major: consecutive sp_size() ranks form
    # one SP group (see populate_registry) — gid selects this rank's group.
    gid = dist.get_rank() // sp_size()
    group = get_group(gid)

    if scatter_idx == 1:
        # Scatter heads, gather sequence: [B, N, S/P, H] -> [B, N/P, S, H]
        N, local_S = dim1, dim2
        # Split the head dim into (P, N/P) and move P to the front so each
        # peer's chunk is contiguous, as all_to_all_single splits along dim 0.
        input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H)
        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # Received P sequence-chunks: fold them back into the sequence dim.
        output = output.permute(1, 2, 0, 3, 4).contiguous()
        output = output.reshape(B, N // sp_size(), sp_size() * local_S, H)
    else:
        # Scatter sequence, gather heads: [B, N/P, S, H] -> [B, N, S/P, H]
        local_N, S = dim1, dim2
        # Split the sequence dim into (P, S/P) and move P to the front.
        input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H)
        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # Received P head-chunks: fold them back into the head dim.
        output = output.permute(1, 0, 2, 3, 4).contiguous()
        output = output.reshape(B, sp_size() * local_N, S // sp_size(), H)

    return output
53+
54+
55+
@torch.library.register_fake("autosp::all_to_all")
def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
    """Meta/fake kernel for autosp::all_to_all: compute only the output shape."""
    B, dim1, dim2, H = input.shape
    world = sp_size()
    if scatter_idx == 1:
        # [B, N, S/P, H] -> [B, N/P, S, H]
        new_shape = (B, dim1 // world, dim2 * world, H)
    else:
        # [B, N/P, S, H] -> [B, N, S/P, H]
        new_shape = (B, dim1 * world, dim2 // world, H)
    return input.new_empty(new_shape)
62+
63+
64+
def _all_to_all_backward_setup(ctx, inputs, output):
65+
_, scatter_idx, gather_idx, name = inputs
66+
ctx.scatter_idx = gather_idx
67+
ctx.gather_idx = scatter_idx
68+
ctx.name = name + "_grad"
69+
70+
71+
def _all_to_all_backward(ctx, grad):
    """Backward of all_to_all: run the inverse all-to-all on the gradient.

    The non-tensor inputs (scatter_idx, gather_idx, name) receive no gradient.
    """
    grad_input = all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name)
    return grad_input, None, None, None


torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import deepspeed.comm as dist
7+
8+
# Maps gid (int) -> dist.ProcessGroup, plus bookkeeping keys
# 'SP_SIZE', 'DP_SIZE', 'is_reg' set by populate_registry().
GROUP_REGISTRY = {}


def register_groups(groups):
    """Create one process group per rank list.

    groups: List[List[int]], e.g. [[0,1],[2,3]]. Already-registered gids are
    left untouched, so repeated calls are safe.
    """
    for gid, ranks in enumerate(groups):
        if gid not in GROUP_REGISTRY:
            GROUP_REGISTRY[gid] = dist.new_group(ranks)


def get_group(gid: int):
    """Return the process group for gid, or the world group when gid is None.

    Raises KeyError if gid was never registered.
    """
    return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group()


def get_registry():
    """Expose the raw registry dict (groups + mesh metadata)."""
    return GROUP_REGISTRY


def is_setup():
    """Return True once populate_registry() has completed."""
    # Idiomatic dict lookup with default instead of `x in d` ternary.
    return GROUP_REGISTRY.get('is_reg', False)


def extract_mesh_size(param_dict):
    """Derive (sp_size, dp_size) from the DeepSpeed config dict.

    Reads 'sequence_parallel_size' (default 1); the world size must be
    divisible by it.
    """
    sp_size = param_dict.get('sequence_parallel_size', 1)
    assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE'
    dp_size = dist.get_world_size() // sp_size

    return sp_size, dp_size


def sp_size():
    """Sequence-parallel group size; requires populate_registry() first."""
    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'

    return GROUP_REGISTRY['SP_SIZE']


def dp_size():
    """Data-parallel group size; requires populate_registry() first."""
    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'

    return GROUP_REGISTRY['DP_SIZE']


def populate_registry(SP_SIZE, DP_SIZE):
    """ Populate rank to SP/DP mesh index. """

    # Idempotent: the mesh is built at most once per process.
    if GROUP_REGISTRY.get('is_reg', False):
        return

    # Ranks are SP-major: each DP replica owns SP_SIZE consecutive ranks,
    # e.g. SP=2, DP=2 -> [[0, 1], [2, 3]].
    group_listing = [list(range(offset, offset + SP_SIZE)) for offset in range(0, SP_SIZE * DP_SIZE, SP_SIZE)]

    register_groups(group_listing)

    ## Extraneous metadata required for proper instantiation. ##
    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
    GROUP_REGISTRY['is_reg'] = True

deepspeed/compile/fx.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
# DeepSpeed Team
55

6-
from typing import Callable, Any, List, Dict
6+
from typing import Callable, Any, List, Dict, Optional
77
from collections import defaultdict
88

99
import torch
10-
from torch.fx import Node, Graph
10+
from torch.fx import Node, Graph, GraphModule
1111

1212
from .util import get_last_uses
1313

@@ -138,3 +138,32 @@ def free_tensors(tensors: List[torch.Tensor]):
138138

139139
# Python version for debugging
140140
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)
141+
142+
143+
def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]:
    """Return the node in gm's graph whose .name equals `name`, or None."""
    matches = (candidate for candidate in gm.graph.nodes if candidate.name == name)
    return next(matches, None)
148+
149+
150+
def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]:
    """Return the node's tensor metadata: meta['val'] if present, else
    meta['example_value'], else None.

    Bug fix: the original `meta.get("val") or meta.get("example_value")`
    triggers Tensor.__bool__, which raises for multi-element tensors and
    silently discards empty / scalar-zero tensors. Test presence with
    `is None` instead of truthiness.
    """
    val = node.meta.get("val")
    if val is not None:
        return val
    return node.meta.get("example_value")
152+
153+
154+
def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]:
    """Return the first node whose Dynamo 'tensor_dict' metadata carries `tag`.

    Returns None when no node is tagged with it.
    """
    for candidate in gm.graph.nodes:
        # Dynamo stores user tensor attributes under meta['tensor_dict']:
        # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048
        meta = candidate.meta.get('tensor_dict')
        if meta and meta.get('tag') == tag:
            return candidate
    return None
163+
164+
165+
def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None):
    """Rewire every user of `node` to consume `replacement` instead.

    Users listed in `exclude` are left untouched.
    """
    skip = exclude if exclude is not None else []
    # Snapshot the users first: replace_input_with mutates node.users.
    for user in list(node.users):
        if user in skip:
            continue
        user.replace_input_with(node, replacement)

deepspeed/compile/init_sp.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import torch
7+
from torch.fx import GraphModule
8+
from torch._functorch.partitioners import OpTypes
9+
from .passes.sp_compile import apply_autosp
10+
from .custom_ops.sp_dp_registry import extract_mesh_size
11+
12+
# Matmul-family ATen ops targeted for selective recomputation by
# register_selective_matmul_recomputation below.
MATMUL_OPS = [
    torch.ops.aten.mm.default,
    torch.ops.aten.bmm.default,
    torch.ops.aten.addmm.default,
    torch.ops.aten._scaled_mm.default,
]
18+
19+
20+
def register_selective_matmul_recomputation():
    """Mark matrix multiplication operations for recomputation during backward pass.

    Reduces peak memory usage for long sequences by recomputing matmuls instead
    of storing their activations. Attention computation dominates matmuls for long sequences.
    Targets mm, bmm, addmm, and scaled_mm operations.
    """

    # Keep a handle to the unpatched method so every non-matmul op keeps its
    # original classification.
    _original_is_compute_intensive = OpTypes.is_compute_intensive

    def is_compute_intensive_wrapper(self, node: torch.fx.Node):
        # Reporting matmuls as NOT compute-intensive steers the partitioner
        # toward recomputing them rather than saving their activations.
        if node.target in MATMUL_OPS:
            return False
        return _original_is_compute_intensive(self, node)

    # NOTE(review): monkey-patches a private torch._functorch API; calling this
    # twice re-wraps the wrapper (harmless here, but unguarded).
    OpTypes.is_compute_intensive = is_compute_intensive_wrapper
36+
37+
38+
def init_autosp(config):
    """Build a torch.compile backend that applies AutoSP, then lowers via Inductor.

    Reads the SP/DP mesh sizes once from the DeepSpeed config; the returned
    backend_fn closes over them.
    """
    mesh_sp, mesh_dp = extract_mesh_size(config._param_dict)

    def backend_fn(gm: GraphModule, real_inputs):
        # Rewrite the graph in place for sequence parallelism, then compile.
        apply_autosp(gm, real_inputs, debug=False, sp_size=mesh_sp, dp_size=mesh_dp)
        return torch._inductor.compile(gm, real_inputs)

    return backend_fn

0 commit comments

Comments
 (0)