
Commit babbb03

Authored by Jan Jirmasek (jirmasek) and Jan Jirmasek
Graph Pass: scaled_dot_product_attention_sliced_q (#2418)
For longer Q sequence lengths (typically >1024), it is beneficial to compute the attention with an algorithm (inspired by Lazy Softmax) that processes Q in chunks. Overall memory usage and execution time should improve (given the chunks execute concurrently, e.g. on the ANE), and in certain cases where models hit OOMs at longer sequence lengths, models using this algorithm still work. This PR implements a new graph pass that can optionally transform the MIL operation `ios18.scaled_dot_product_attention` into a set of operations that compute the attention over chunks of Q.

Parameters of the new graph pass (see the usage sketch below):

* `min_seq_length` (default: 1280) - the original MIL operation is only transformed if the sequence length of Q is greater than or equal to this value.
* `seq_length_divider` (default: 16) - defines the size of the chunks (`chunk_size = sequence_length / seq_length_divider`).

Example performance of the Depth-Anything model running on the ANE:

* original: execution time 131.55 ms, memory usage 169.67 MB
* with the transformations applied by this graph pass: execution time 86.84 ms, memory usage 93.34 MB

---------

Co-authored-by: Jan Jirmasek <jjirmasek@apple.com>
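The pass is opt-in: it has to be appended to the conversion pass pipeline, where its options can also be set. A minimal usage sketch, mirroring how the unit test added in this PR configures the pass; `traced_model` and `coreml_inputs` are placeholders you would supply yourself:

    import coremltools as ct

    # Start from the default pipeline and opt in to the new graph pass.
    pipeline = ct.PassPipeline.DEFAULT
    pipeline.append_pass("common::scaled_dot_product_attention_sliced_q")
    pipeline.set_options(
        "common::scaled_dot_product_attention_sliced_q",
        {"min_seq_length": 1280, "seq_length_divider": 16},  # defaults shown; tune per model
    )

    mlmodel = ct.convert(
        traced_model,              # placeholder: a torch.jit.trace()'d PyTorch model
        inputs=coreml_inputs,      # placeholder: list of ct.TensorType(...) descriptions
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.iOS18,  # the pass only transforms iOS18+ functions
        pass_pipeline=pipeline,
    )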
1 parent b919d0c commit babbb03

File tree

4 files changed: +327 -0 lines


coremltools/converters/mil/mil/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -44,5 +44,6 @@
     optimize_state,
     optimize_tensor_operation,
     preprocess,
+    transformer,
     symbol_transform,
 )
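Adding `transformer` to this import list is what causes the new module to be loaded when the passes package is imported, so its `@register_pass(namespace="common")` decorator runs and the pass becomes addressable by its namespaced name. A rough sketch of looking the registered instance up afterwards; the `PASS_REGISTRY` import path and indexing are my assumption, not something shown in this diff:

    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    # Fetch the registered pass instance and configure it directly.
    sdpa_pass = PASS_REGISTRY["common::scaled_dot_product_attention_sliced_q"]
    sdpa_pass.min_seq_length = 1024      # property setters validate the option values
    sdpa_pass.seq_length_divider = 8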
coremltools/converters/mil/mil/passes/defs/transformer.py (new file)

Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import ClassVar, List, Tuple

import numpy as np

from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types

from coremltools import _logger as logger


@register_pass(namespace="common")
class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
    """
    Replace the ios18.scaled_dot_product_attention operation with a memory-efficient
    implementation that computes the attention over slices of Q. The benefits are most
    visible for longer Q sequence lengths.

    Graph pass options:
      - min_seq_length: int
        Only operations with a Q sequence length greater than or equal to this value will be transformed.
      - seq_length_divider: int
        Defines the size of the chunks of Q processed in SDPA (chunk_size = seq_length / seq_length_divider).
    """

    _DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
    _DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16

    _min_seq_length: int
    _seq_length_divider: int

    def __init__(self):
        super().__init__()
        self._min_seq_length = self._DEFAULT_MIN_SEQ_LENGTH
        self._seq_length_divider = self._DEFAULT_SEQ_LENGTH_DIVIDER

    @property
    def min_seq_length(self) -> int:
        return self._min_seq_length

    @min_seq_length.setter
    def min_seq_length(self, length: int) -> None:
        if not isinstance(length, int):
            raise ValueError("pass option min_seq_length must be an int")
        if length < 0:
            raise ValueError("pass option min_seq_length must be >= 0")
        self._min_seq_length = length

    @property
    def seq_length_divider(self) -> int:
        return self._seq_length_divider

    @seq_length_divider.setter
    def seq_length_divider(self, divider: int) -> None:
        if not isinstance(divider, int):
            raise ValueError("pass option seq_length_divider must be an int")
        if divider < 1:
            raise ValueError("pass option seq_length_divider must be >= 1")
        self._seq_length_divider = divider

    def apply(self, prog):
        for f in prog.functions.values():
            if f.opset_version < target.iOS18:
                logger.debug(f"ignoring block '{f.name}', target {f.opset_version} (required min iOS18)")
                return

            for op in list(f.operations):
                if op.op_type == "scaled_dot_product_attention":
                    self._replace_scaled_dot_product_attention(op)

    @staticmethod
    def _get_input_vars(op):
        mandatory_params = ["query", "key", "value"]
        inputs = {}
        for param in mandatory_params:
            inputs[param] = op.inputs.get(param)
            if inputs[param] is None:
                raise ValueError(f"operation 'scaled_dot_product_attention': mandatory input '{param}' not present")
        return tuple([inputs[param] for param in mandatory_params]) + (op.inputs.get("attn_mask"),)

    @staticmethod
    def _split_to_chunks(seq_length: int, count: int) -> List[Tuple[int, int]]:
        chunk_size = max(seq_length // count, 1)
        remainder = seq_length % count

        result = []
        chunk_start = 0
        for i in range(count):
            if chunk_start >= seq_length:
                break
            chunk_end = chunk_start + chunk_size + (1 if i < remainder else 0)
            result.append((chunk_start, chunk_end))
            chunk_start = chunk_end

        return result

    def _replace_scaled_dot_product_attention(self, op):
        q, k, v, mask = self._get_input_vars(op)

        q_size = len(q.shape)
        q_seq_length = q.shape[-2]
        if q_seq_length < self._min_seq_length:
            logger.debug(
                f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length})"
            )
            return

        dims = q.shape[-1]
        normalize_factor = float(dims) ** -0.5

        q_dtype = types.nptype_from_builtin(type(q.dtype()))

        chunks = self._split_to_chunks(q_seq_length, self._seq_length_divider)

        concat_out = None
        with op.enclosing_block:
            if mask is not None:
                if mask.dtype == types.bool:
                    cond_out = mb.logical_not(x=mask, before_op=op)
                    mask_zeros = mb.const(val=np.zeros(mask.shape, dtype=q_dtype), before_op=op)
                    mask_float = mb.select(cond=cond_out, a=q_dtype(-np.inf), b=mask_zeros, before_op=op)
                else:
                    mask_float = mask

            for chunk_start, chunk_end in chunks:
                # Get a chunk of Q.
                slice_begin = [0] * (q_size - 2) + [chunk_start, 0]
                slice_end = list(q.shape[:-2] + (chunk_end, dims))
                slice_end_mask = tuple([True] * (q_size - 2) + [False, True])
                slice_out = mb.slice_by_index(
                    x=q,
                    begin=slice_begin,
                    end=slice_end,
                    end_mask=slice_end_mask,
                    before_op=op,
                )

                # Calculate the chunk of Q x K^T.
                matmul_out = mb.matmul(x=slice_out, y=k, transpose_x=False, transpose_y=True, before_op=op)
                mul_out = mb.mul(x=matmul_out, y=np.array(normalize_factor, dtype=q_dtype), before_op=op)

                # Apply the attention mask.
                if mask is not None:
                    if mask.shape[-2] == 1:
                        mul_out = mb.add(x=mul_out, y=mask_float, before_op=op)
                    else:
                        mask_out = mb.slice_by_index(
                            x=mask_float,
                            begin=[chunk_start, 0],
                            end=[chunk_end, mask.shape[-1]],
                            end_mask=[False, True],
                            before_op=op,
                        )
                        mul_out = mb.add(x=mul_out, y=mask_out, before_op=op)

                # Calculate softmax of the product.
                softmax_out = mb.softmax(x=mul_out, axis=-1, before_op=op)

                # Calculate the chunk of attention.
                matmul_v_out = mb.matmul(
                    x=softmax_out,
                    y=v,
                    transpose_x=False,
                    transpose_y=False,
                    before_op=op,
                )

                # Add the chunk of attention to the result value.
                concat_values = [concat_out] if concat_out is not None else []
                concat_out = mb.concat(values=concat_values + [matmul_v_out], axis=-2, interleave=False, before_op=op)

        # Remove the original SDPA operation.
        op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=op,
            old_var=op.outputs[0],
            new_var=concat_out,
        )
        op.enclosing_block.remove_ops([op])
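Why slicing Q is exact: softmax is taken over the last axis only, so each row of softmax(Q K^T / sqrt(d)) V depends solely on the corresponding row of Q, and computing the attention chunk-by-chunk over Q's rows and concatenating reproduces the full result. A small NumPy sketch of that equivalence (toy shapes, no mask, even chunking rather than the remainder-spreading `_split_to_chunks`; illustration only, not coremltools code):

    import numpy as np

    def full_attention(q, k, v):
        # Standard SDPA: softmax(Q K^T / sqrt(d)) V, softmax over the last axis.
        scores = (q @ k.T) / np.sqrt(q.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ v

    def sliced_attention(q, k, v, divider=4):
        # Same computation, processing Q in row chunks, as the graph pass does.
        chunk = max(q.shape[0] // divider, 1)
        outs = [full_attention(q[s:s + chunk], k, v) for s in range(0, q.shape[0], chunk)]
        return np.concatenate(outs, axis=0)

    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((16, 8)).astype(np.float32) for _ in range(3))
    np.testing.assert_allclose(full_attention(q, k, v), sliced_attention(q, k, v), rtol=1e-5, atol=1e-6)
    print("sliced attention matches full attention")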

coremltools/converters/mil/mil/passes/tests/test_passes.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import itertools
88
import unittest
99

10+
from typing import ClassVar, Dict, List, Optional
11+
1012
import numpy as np
1113
import pytest
1214
import torch
@@ -38,6 +40,7 @@
3840
get_op_types_in_program,
3941
)
4042
from coremltools.models.utils import _macos_version
43+
from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil
4144

4245
np.random.seed(1984)
4346
_VALIDATE_MODEL = True
@@ -7371,3 +7374,133 @@ def prog(x, y, z):
73717374

73727375
apply_pass_and_basic_check(prog, "common::fuse_stack_split")
73737376
assert get_op_types_in_program(prog) == ["stack", "split"] + ["squeeze"] * 3
7377+
7378+
7379+
class TestScaledDotProductAttentionSlicedQ:
7380+
7381+
class AttentionPyTorch(torch.nn.Module):
7382+
@staticmethod
7383+
def forward(q, k, v, attn_mask=None):
7384+
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
7385+
7386+
@staticmethod
7387+
def _get_example_inputs(
7388+
shape_size: int = 3,
7389+
qkv_same_shape: bool = True,
7390+
dtype: torch.dtype = torch.float16,
7391+
attn_mask_dtype: Optional[torch.dtype] = None,
7392+
):
7393+
batches, seq_length, dimensions = 4, 256, 768
7394+
q_shape = (batches, seq_length, dimensions)
7395+
kv_shape = q_shape if qkv_same_shape else (batches, seq_length - 16, dimensions)
7396+
if shape_size > 3:
7397+
q_shape = tuple([1] * (shape_size - len(q_shape)) + list(q_shape))
7398+
kv_shape = tuple([1] * (shape_size - len(kv_shape)) + list(kv_shape))
7399+
inputs = {
7400+
"q": torch.rand(q_shape, dtype=dtype),
7401+
"k": torch.rand(kv_shape, dtype=dtype),
7402+
"v": torch.rand(kv_shape, dtype=dtype),
7403+
}
7404+
if attn_mask_dtype is not None:
7405+
if attn_mask_dtype == torch.bool:
7406+
inputs["attn_mask"] = torch.randint(0, 2, (seq_length, seq_length), dtype=torch.bool)
7407+
else:
7408+
inputs["attn_mask"] = torch.randn((seq_length, seq_length), dtype=dtype)
7409+
return inputs
7410+
7411+
@staticmethod
7412+
def _get_trace_coreml_inputs(example_inputs: Dict[str, torch.Tensor]):
7413+
model_inputs = [example_inputs[key] for key in ["q", "k", "v"]]
7414+
if "attn_mask" in example_inputs:
7415+
model_inputs.append(example_inputs["attn_mask"])
7416+
7417+
coreml_model_inputs = []
7418+
for key in ["q", "k", "v", "attn_mask"]:
7419+
if key in example_inputs:
7420+
dtype = example_inputs[key].numpy().dtype
7421+
if dtype == bool:
7422+
dtype = np.float32
7423+
coreml_model_inputs.append(ct.TensorType(key, shape=example_inputs[key].shape, dtype=dtype))
7424+
7425+
return model_inputs, coreml_model_inputs
7426+
7427+
def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
7428+
pipeline_1 = ct.PassPipeline.DEFAULT
7429+
7430+
pipeline_2 = ct.PassPipeline.DEFAULT
7431+
pipeline_2.append_pass("common::scaled_dot_product_attention_sliced_q")
7432+
7433+
pipeline_3 = ct.PassPipeline.DEFAULT
7434+
pipeline_3.append_pass("common::scaled_dot_product_attention_sliced_q")
7435+
pipeline_3.set_options("common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256})
7436+
7437+
pipeline_4 = ct.PassPipeline.DEFAULT
7438+
pipeline_4.append_pass("common::scaled_dot_product_attention_sliced_q")
7439+
pipeline_4.set_options(
7440+
"common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256, "seq_length_divider": 32}
7441+
)
7442+
7443+
model = self.AttentionPyTorch()
7444+
model_inputs, coreml_model_inputs = self._get_trace_coreml_inputs(example_inputs)
7445+
7446+
coreml_models = [
7447+
ct.convert(
7448+
torch.jit.trace(model, model_inputs).eval(),
7449+
inputs=coreml_model_inputs,
7450+
minimum_deployment_target=ct.target.iOS18,
7451+
convert_to="mlprogram",
7452+
compute_units=ct.ComputeUnit.ALL,
7453+
skip_model_load=False,
7454+
pass_pipeline=pipeline,
7455+
)
7456+
for pipeline in [pipeline_1, pipeline_2, pipeline_3, pipeline_4]
7457+
]
7458+
7459+
model_specs = [coreml_model.get_spec() for coreml_model in coreml_models]
7460+
progs = []
7461+
for i in range(len(coreml_models)):
7462+
progs.append(
7463+
_milproto_to_pymil(
7464+
model_spec=model_specs[i],
7465+
specification_version=model_specs[i].specificationVersion,
7466+
file_weights_dir=coreml_models[i].weights_dir,
7467+
)
7468+
)
7469+
7470+
ops_counts = [len(prog.functions["main"].operations) for prog in progs]
7471+
7472+
assert ops_counts[0] == 1 or ops_counts[0] == 3 # (attn_mask might be cast to bool from input fp16 dtype)
7473+
assert ops_counts[1] == 1 or ops_counts[1] == 3 # the Q seq length is less than the default min seq length
7474+
assert ops_counts[2] >= 6 * 16 # 6 ops (without consts) per slice
7475+
assert ops_counts[3] >= 6 * 32
7476+
7477+
predict_inputs = copy.deepcopy(example_inputs)
7478+
if "attn_mask" in predict_inputs:
7479+
predict_inputs["attn_mask"] = predict_inputs["attn_mask"].to(dtype=torch.float32)
7480+
7481+
outputs = [list(coreml_model.predict(predict_inputs).values())[0] for coreml_model in coreml_models]
7482+
7483+
for i in range(1, len(outputs)):
7484+
assert outputs[0].shape == outputs[i].shape
7485+
np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01)
7486+
7487+
def test_scaled_dot_product_attention_sliced(self):
7488+
# Confirm the basic scenario.
7489+
example_inputs = self._get_example_inputs()
7490+
self.verify_sdpa_outputs(example_inputs)
7491+
7492+
# Confirm sdpa with Q, K and V as 4D tensors.
7493+
example_inputs = self._get_example_inputs(shape_size=4)
7494+
self.verify_sdpa_outputs(example_inputs)
7495+
7496+
# Confirm sdpa with attn_mask as a bias.
7497+
example_inputs = self._get_example_inputs(attn_mask_dtype=torch.float16)
7498+
self.verify_sdpa_outputs(example_inputs)
7499+
7500+
# Confirm sdpa with attn_mask as boolean flags.
7501+
example_inputs = self._get_example_inputs(attn_mask_dtype=torch.bool)
7502+
self.verify_sdpa_outputs(example_inputs)
7503+
7504+
# Confirm sdpa works well with different shapes for Q and K & V.
7505+
example_inputs = self._get_example_inputs(qkv_same_shape=False)
7506+
self.verify_sdpa_outputs(example_inputs)
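To exercise just this new test class locally, something like `pytest coremltools/converters/mil/mil/passes/tests/test_passes.py -k TestScaledDotProductAttentionSlicedQ` should work, assuming an environment where the converted iOS18 mlprogram can actually be loaded and run (the test calls `coreml_model.predict`).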

docs/source/coremltools.converters.mil.mil.passes.defs.rst

Lines changed: 8 additions & 0 deletions
@@ -147,3 +147,11 @@ symbol_transform
 .. automodule:: coremltools.converters.mil.mil.passes.defs.symbol_transform

 .. autoclass:: materialize_symbolic_shape_program
+
+
+transformer
+---------------------------------------------------------
+
+.. automodule:: coremltools.converters.mil.mil.passes.defs.transformer
+
+.. autoclass:: scaled_dot_product_attention_sliced_q
