[SOT] Mark dynamic dims by type annotations #2771

Merged
6 changes: 3 additions & 3 deletions fastdeploy/config.py
@@ -335,11 +335,11 @@ class GraphOptimizationConfig:
cudagraph_splitting_ops = ["paddle.unified_attention"]

Note: If want to use subgraph capture functionality in a dynamic graph,
- can manually split the model into multiple layers and apply the @support_cuda_graph decorator
+ can manually split the model into multiple layers and apply the @support_graph_optimization decorator
only to the layer where CUDA graph functionality is required.
"""
- cudagraph_splitting_ops = Optional[list[str]]
- """" Whether to use a full cuda graph for the entire forward pass rather than
+ cudagraph_splitting_ops: list[str] = field(default_factory=list)
+ """ Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
full_cuda_graph: bool = True
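
For illustration only, a minimal sketch of the pattern the docstring note above describes, where a single sub-layer opts into graph capture. The decorator's import path and exact signature are assumptions, not something this diff defines:

```python
import paddle
from paddle import nn

# Assumed import path for the decorator mentioned in the docstring above.
from fastdeploy.model_executor.graph_optimization.decorator import (
    support_graph_optimization,
)


@support_graph_optimization  # only this sub-layer is captured/optimized
class DecoderBlock(nn.Layer):
    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        return hidden_states * 2


class ToyModel(nn.Layer):
    """The surrounding layers keep running in dynamic-graph mode."""

    def __init__(self):
        super().__init__()
        self.decoder = DecoderBlock()

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        return self.decoder(hidden_states)
```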
4 changes: 2 additions & 2 deletions fastdeploy/engine/engine.py
@@ -937,11 +937,11 @@ def _setting_environ_variables(self):
"SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
Comment: Is the default `no` here?

Member Author: In the framework, the default is to specialize (i.e. "1"); in FD the default is not to specialize, hence "no".

"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
"FLAGS_pir_interpreter_record_stream_for_gc_cache",
default="1",
"FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
),
"FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
"FLAGS_parameters_persistent_mode_in_dy2st", default="1"
191 changes: 191 additions & 0 deletions fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py
@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from collections.abc import Callable
from functools import partial
from typing import Annotated, Any, TypeVar, Union, get_origin, get_type_hints

import paddle
from paddle import Tensor
from paddleformers.utils.log import logger
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")

Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
def __init__(self, dims: int | tuple[int]):
self.dims = dims if isinstance(dims, tuple) else (dims,)

def __repr__(self):
return f"DynamicDims({self.dims})"


class DynamicDimTypeResolver:
"""
Base class for dynamic dimension type resolvers.
This class provides a mechanism to register and resolve dynamic dimensions
based on type annotations. It uses a registry pattern to allow multiple
resolvers to be registered and used in a flexible manner.
"""

ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

@classmethod
def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
return resolver_cls

@abstractmethod
def type_match(self, tp: type[Any]) -> bool:
raise NotImplementedError

@abstractmethod
def extract_inner_types(
self, data: Any, data_name: str, tp: type[Any]
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
raise NotImplementedError

def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
inner_types = self.extract_inner_types(data, data_name, tp)
for accessor, inner_data_name, inner_type in inner_types:
self.generic_resolve(accessor(data), inner_data_name, inner_type)

def generic_resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
if resolver.type_match(tp):
return resolver.resolve(data, data_name, tp)
runtime_tp = type(data)
if runtime_tp is not tp and resolver.type_match(runtime_tp):
return resolver.resolve(data, data_name, runtime_tp)
else:
logger.debug(f"No resolver found for type {tp} and data {data_name}")


@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
def type_match(self, tp: type[Any]) -> bool:
return dataclasses.is_dataclass(tp) and isinstance(tp, type)

def extract_inner_types(
self, data: Any, data_name: str, tp: type[Any]
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
type_hints = get_type_hints(tp, include_extras=True)
return [ # type: ignore
(
# bind name by partial to avoid capture wrong free vars
partial(lambda name, dt: getattr(dt, name), field.name),
f"{data_name}.{field.name}",
type_hints[field.name],
)
for field in dataclasses.fields(tp)
]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
def type_match(self, tp) -> bool:
return get_origin(tp) is Union and len(tp.__args__) == 2 and tp.__args__[1] is type(None) # noqa: E721

def extract_inner_types(
self, data: Any, data_name: str, tp: type[Any]
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
if data is None:
return []
inner_type = tp.__args__[0]
return [(lambda x: x, data_name, inner_type)] # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
def type_match(self, tp: type[Any]) -> bool:
return get_origin(tp) is list

def extract_inner_types(
self, data: Any, data_name: str, tp: type[Any]
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
if not data:
return []
inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))] # type: ignore


@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"

def type_match(self, tp: type[Any]) -> bool:
return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

def extract_inner_types(
self, data: Any, data_name: str, tp: type[Any]
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
if isinstance(fields, str):
raise TypeError(
f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}"
)
inner_types_dict = typing.get_type_hints(tp)
return [
(partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type)
for field_name, inner_type in inner_types_dict.items()
]


@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
def type_match(self, tp: type[Any]) -> bool:
return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
base_type, *metadata = typing.get_args(tp)
# Filter out DynamicDims instances
dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
if not dynamic_dims:
return
if len(dynamic_dims) > 1:
raise ValueError("Multiple DynamicDims annotations found. Only one is allowed.")
dynamic_dims = dynamic_dims[0].dims
if not isinstance(data, Tensor):
raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
paddle.jit.marker.dynamic_dims(data, dynamic_dims)


@DynamicDimTypeResolver.register_resolver
class TensorImplicitFirstDimOnlyDynamicDimTypeResolver(DynamicDimTypeResolver):
def type_match(self, tp: type[Any]) -> bool:
return tp is Tensor

def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
# Tensor annotation has implicit dynamic_dims=(0, )
dynamic_dims = (0,)
if not isinstance(data, Tensor):
raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
paddle.jit.marker.dynamic_dims(data, dynamic_dims)


def resolve_dynamic_dims(arg: Any, arg_name: str, annotation: type[Any]) -> None:
DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)
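
For reference, a minimal usage sketch of the resolvers above. The `ToyMetadata` type is hypothetical, and actually running it requires a Paddle build that exposes `paddle.jit.marker.dynamic_dims`:

```python
from dataclasses import dataclass
from typing import Annotated, Optional

import paddle
from paddle import Tensor

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    DynamicDims,
    resolve_dynamic_dims,
)


@dataclass
class ToyMetadata:
    # Dim 0 is marked dynamic explicitly via Annotated metadata.
    rotary_embs: Annotated[Tensor, DynamicDims(0)]
    # A bare/Optional Tensor annotation falls back to the implicit (0,).
    attn_mask: Optional[Tensor] = None


meta = ToyMetadata(
    rotary_embs=paddle.rand([8, 64]),
    attn_mask=paddle.rand([8, 8]),
)
# Recursively walks the dataclass, Optional and Annotated layers and calls
# paddle.jit.marker.dynamic_dims(...) on every matched tensor.
resolve_dynamic_dims(meta, "meta", ToyMetadata)
```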
@@ -14,14 +14,101 @@
# limitations under the License.
"""

from typing import Callable, Optional
import functools
import inspect
import types
from typing import Callable, Optional, TypeVar, get_type_hints

from paddle.jit.dy2static.utils import Backend
from paddle.jit import sot
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
from paddleformers.utils.log import logger
from typing_extensions import ParamSpec

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
CudaGraphPiecewiseBackend,
)
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
resolve_dynamic_dims,
)

P = ParamSpec("P")
T = TypeVar("T")


# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
def create_in_warmup_mode():
cnt = 0

def in_warmup_mode():
nonlocal cnt
cnt += 1
return cnt < 32

return in_warmup_mode


in_warmup_mode = create_in_warmup_mode()


def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
forward_fn = fn
forward_sig = inspect.signature(forward_fn)
forward_type_hints = get_type_hints(forward_fn)
static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
unsafe_static_forward_fn = None
need_warmup = True

@functools.wraps(forward_fn)
def warmup_impl(self, *args, **kwargs):
nonlocal unsafe_static_forward_fn, need_warmup
bound_args = forward_sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
for name, arg in bound_args.arguments.items():
if name not in forward_type_hints:
continue
annotation = forward_type_hints[name]
resolve_dynamic_dims(arg, name, annotation)

result = static_forward_fn(self, *args, **kwargs)
original_code = forward_fn.__code__
(new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[
original_code
]
# Check has only one graph
if len(new_guarded_codes) > 1:
logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
unsafe_static_forward_fn = None
need_warmup = False
return result
# Check generated code has no break graph
new_code = new_guarded_codes[0][0][0]
if any(name.startswith("$") for name in new_code.co_names): # TODO(SigureMo): It's a internal impl
logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
unsafe_static_forward_fn = None
need_warmup = False
return result
unsafe_static_forward_fn = types.FunctionType(
new_code,
forward_fn.__globals__,
forward_fn.__name__,
forward_fn.__defaults__,
forward_fn.__closure__,
)
return result

@functools.wraps(forward_fn)
def static_forward(self, *args, **kwargs):
nonlocal need_warmup
is_warmup = in_warmup_mode() and need_warmup
if is_warmup:
return warmup_impl(self, *args, **kwargs)
nonlocal unsafe_static_forward_fn
if unsafe_static_forward_fn is None:
return static_forward_fn(self, *args, **kwargs)
return unsafe_static_forward_fn(self, *args, **kwargs)

return static_forward


class GraphOptBackend:
@@ -42,10 +129,14 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
# 1. Prepare cuda grpah input buffers (contain output of subgraphs)

# 2. Convert dynamic grpah to static graph
from paddle.jit import sot

backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
backend = (
ToStaticBackend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
)
self.runnable = apply_to_static_optimization(
self.runnable.__func__,
backend,
).__get__(self.runnable.__self__)

def __call__(self, **kwargs):
if not self.fd_config.graph_opt_config.use_cudagraph:
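
The `.__func__` / `.__get__` dance above is easy to misread; here is a self-contained sketch of the same rebinding pattern (all names are illustrative, not part of this diff):

```python
import functools


def with_logging(fn):
    """Wrap a plain (unbound) function, preserving its metadata."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        print(f"calling {fn.__name__}")
        return fn(self, *args, **kwargs)

    return wrapper


class Runner:
    def forward(self, x):
        return x + 1


runner = Runner()
bound = runner.forward
# Unbind to the underlying function, wrap it, then re-bind the wrapper to the
# original instance so `self` is passed through exactly as before.
runner.forward = with_logging(bound.__func__).__get__(bound.__self__)
assert runner.forward(1) == 2
```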
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/attention/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .append_attn_backend import AppendAttentionBackend
+ from .attention import Attention
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .block_multihead_attn_backend import BlockAttentionBackend
@@ -32,4 +33,5 @@
"FlashAttentionBackend",
"IluvatarAttnBackend",
"BlockAttentionBackend",
"Attention",
Collaborator: I recall that adding this causes a circular import?

Member Author: It won't, unless a pre-existing design flaw causes it.
]
@@ -66,20 +66,23 @@ class AppendAttentionMetadata(AttentionMetadata):
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
- encoder_block_shape_q: Optional[paddle.Tensor] = None
- decoder_block_shape_q: Optional[paddle.Tensor] = None
+ encoder_block_shape_q: int = -1
+ decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"

# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
- kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+ kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


class AppendAttentionBackend(AttentionBackend):
"""
AppendAttentionBackend backend implementation.
"""

+ __infer_dynamic_dims_fields__ = ["attention_metadata"]
+ attention_metadata: AppendAttentionMetadata

def __init__(
self,
fd_config: FDConfig,
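
For illustration, a hedged sketch of how the `__infer_dynamic_dims_fields__` marker above is consumed by `resolve_dynamic_dims`; the toy classes are stand-ins, not the real backend:

```python
from dataclasses import dataclass
from typing import Optional

import paddle
from paddle import Tensor

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    resolve_dynamic_dims,
)


@dataclass
class ToyAttentionMetadata:
    # A bare/Optional Tensor annotation picks up the implicit dynamic dim (0,);
    # plain ints are skipped because no resolver matches `int`.
    block_tables: Optional[Tensor] = None
    encoder_block_shape_q: int = -1


class ToyAttentionBackend:
    # Opts this (non-dataclass) type into resolution; its annotated fields are
    # then discovered via typing.get_type_hints and walked recursively.
    __infer_dynamic_dims_fields__ = ["attention_metadata"]
    attention_metadata: ToyAttentionMetadata

    def __init__(self):
        self.attention_metadata = ToyAttentionMetadata(
            block_tables=paddle.zeros([4, 16], dtype="int32")
        )


backend = ToyAttentionBackend()
resolve_dynamic_dims(backend, "backend", ToyAttentionBackend)
```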