# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import ClassVar, List, Tuple

import numpy as np

from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types

from coremltools import _logger as logger


@register_pass(namespace="common")
class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
| 21 | + """ |
| 22 | + Replace the ios18.scaled_dot_product_attention operation with a memory efficient |
| 23 | + implementation of attention calculation based on slicing Q. The benefits are clearly |
| 24 | + visible for higher Q sequence lengths, though. |
| 25 | +
|
| 26 | + Graph pass options: |
| 27 | + - min_seq_length: int |
| 28 | + Only operations working with Q of sequence length greater or equal to this value will be transformed. |
| 29 | + - seq_length_divider: int |
| 30 | + Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider) |
| 31 | + """ |

    _DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
    _DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16

    _min_seq_length: int
    _seq_length_divider: int

    def __init__(self):
        super().__init__()
        self._min_seq_length = self._DEFAULT_MIN_SEQ_LENGTH
        self._seq_length_divider = self._DEFAULT_SEQ_LENGTH_DIVIDER

    @property
    def min_seq_length(self) -> int:
        return self._min_seq_length

    @min_seq_length.setter
    def min_seq_length(self, length: int) -> None:
        if not isinstance(length, int):
            raise ValueError("pass option min_seq_length must be an int")
        if length < 0:
            raise ValueError("pass option min_seq_length must be >= 0")
        self._min_seq_length = length

    @property
    def seq_length_divider(self) -> int:
        return self._seq_length_divider

    @seq_length_divider.setter
    def seq_length_divider(self, divider: int) -> None:
        if not isinstance(divider, int):
            raise ValueError("pass option seq_length_divider must be an int")
        if divider < 1:
            raise ValueError("pass option seq_length_divider must be >= 1")
        self._seq_length_divider = divider

    def apply(self, prog):
        for f in prog.functions.values():
            if f.opset_version < target.iOS18:
                logger.debug(
                    f"ignoring function '{f.name}': opset {f.opset_version} is below the required minimum (iOS18)"
                )
                continue

            for op in list(f.operations):
                if op.op_type == "scaled_dot_product_attention":
                    self._replace_scaled_dot_product_attention(op)

    @staticmethod
    def _get_input_vars(op):
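        # Returns the op's (query, key, value, attn_mask) vars in that order; attn_mask
        # is optional and comes back as None when the op was built without a mask.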
        mandatory_params = ["query", "key", "value"]
        inputs = {}
        for param in mandatory_params:
            inputs[param] = op.inputs.get(param)
            if inputs[param] is None:
                raise ValueError(f"operation 'scaled_dot_product_attention': mandatory input '{param}' not present")
        return tuple([inputs[param] for param in mandatory_params]) + (op.inputs.get("attn_mask"),)

    @staticmethod
    def _split_to_chunks(seq_length: int, count: int) -> List[Tuple[int, int]]:
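        # Worked example (illustrative): _split_to_chunks(10, 4) computes chunk_size = 2
        # and remainder = 2, and returns [(0, 3), (3, 6), (6, 8), (8, 10)]; the first
        # `remainder` chunks are one element longer, so the chunks exactly cover the sequence.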
        chunk_size = max(seq_length // count, 1)
        remainder = seq_length % count

        result = []
        chunk_start = 0
        for i in range(count):
            if chunk_start >= seq_length:
                break
            chunk_end = chunk_start + chunk_size + (1 if i < remainder else 0)
            result.append((chunk_start, chunk_end))
            chunk_start = chunk_end

        return result

    def _replace_scaled_dot_product_attention(self, op):
        q, k, v, mask = self._get_input_vars(op)

        q_size = len(q.shape)
        q_seq_length = q.shape[-2]
        if q_seq_length < self._min_seq_length:
            logger.debug(
                f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length})"
            )
            return

        dims = q.shape[-1]
        normalize_factor = float(dims) ** -0.5
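        # Standard SDPA scaling: attention(Q, K, V) = softmax(Q @ K^T / sqrt(dims)) @ V,
        # computed below one chunk of Q at a time.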

        q_dtype = types.nptype_from_builtin(type(q.dtype()))

        chunks = self._split_to_chunks(q_seq_length, self._seq_length_divider)

        concat_out = None
        with op.enclosing_block:
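            # A boolean attn_mask is converted to an additive float mask: 0 where attention
            # is allowed and -inf where it is blocked, so it can be added directly to the scores.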
            if mask is not None:
                if mask.dtype == types.bool:
                    cond_out = mb.logical_not(x=mask, before_op=op)
                    mask_zeros = mb.const(val=np.zeros(mask.shape, dtype=q_dtype), before_op=op)
                    mask_float = mb.select(cond=cond_out, a=q_dtype(-np.inf), b=mask_zeros, before_op=op)
                else:
                    mask_float = mask

            for chunk_start, chunk_end in chunks:
                # Get a chunk of Q.
                slice_begin = [0] * (q_size - 2) + [chunk_start, 0]
                slice_end = list(q.shape[:-2] + (chunk_end, dims))
                slice_end_mask = tuple([True] * (q_size - 2) + [False, True])
                slice_out = mb.slice_by_index(
                    x=q,
                    begin=slice_begin,
                    end=slice_end,
                    end_mask=slice_end_mask,
                    before_op=op,
                )

                # Calculate this chunk of Q x K^T and apply the 1/sqrt(dims) scaling.
                matmul_out = mb.matmul(x=slice_out, y=k, transpose_x=False, transpose_y=True, before_op=op)
                mul_out = mb.mul(x=matmul_out, y=np.array(normalize_factor, dtype=q_dtype), before_op=op)

                # Apply the attention mask: broadcast it when it covers a single query row,
                # otherwise slice out the rows that correspond to this chunk of Q.
                if mask is not None:
                    if mask.shape[-2] == 1:
                        mul_out = mb.add(x=mul_out, y=mask_float, before_op=op)
                    else:
                        mask_out = mb.slice_by_index(
                            x=mask_float,
                            begin=[chunk_start, 0],
                            end=[chunk_end, mask.shape[-1]],
                            end_mask=[False, True],
                            before_op=op,
                        )
                        mul_out = mb.add(x=mul_out, y=mask_out, before_op=op)

                # Calculate softmax of the product.
                softmax_out = mb.softmax(x=mul_out, axis=-1, before_op=op)

                # Calculate the chunk of attention.
                matmul_v_out = mb.matmul(
                    x=softmax_out,
                    y=v,
                    transpose_x=False,
                    transpose_y=False,
                    before_op=op,
                )

                # Add the chunk of attention to the result value.
                concat_values = [concat_out] if concat_out is not None else []
                concat_out = mb.concat(values=concat_values + [matmul_v_out], axis=-2, interleave=False, before_op=op)

        # Reroute consumers of the original SDPA output to the concatenated result and
        # remove the original SDPA operation.
        op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=op,
            old_var=op.outputs[0],
            new_var=concat_out,
        )
        op.enclosing_block.remove_ops([op])