[SOT] Mark dynamic dims by type annotations #2771
@@ -0,0 +1,159 @@
from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from collections.abc import Callable
from functools import partial
from typing import Annotated, Any, TypeVar, Union, get_origin, get_type_hints

import paddle
from paddle import Tensor
from paddleformers.utils.log import logger
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")

Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
    def __init__(self, dims: int | tuple[int, ...]):
        self.dims = dims if isinstance(dims, tuple) else (dims,)

    def __repr__(self):
        return f"DynamicDims({self.dims})"

class DynamicDimTypeResolver:
    ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

    @classmethod
    def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
        cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
        return resolver_cls

    @abstractmethod
    def type_match(self, tp) -> bool:
        raise NotImplementedError

    @abstractmethod
    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        raise NotImplementedError

    def resolve(self, data, data_name, tp) -> None:
Review comment (translated): All function parameters need type annotations. Reply: Added.
        inner_types = self.extract_inner_types(data, data_name, tp)
        for accessor, inner_data_name, inner_type in inner_types:
            self.generic_resolve(accessor(data), inner_data_name, inner_type)

    def generic_resolve(self, data, data_name, tp) -> None:
        # assert isinstance(data, tp), f"Expected {data_name} has type {tp}, but got {type(data)}"
        for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
            if resolver.type_match(tp):
                return resolver.resolve(data, data_name, tp)
            runtime_tp = type(data)
            if runtime_tp is not tp and resolver.type_match(runtime_tp):
                return resolver.resolve(data, data_name, runtime_tp)
        else:
            # for-else: reached only when no resolver matched (every match returns above)
            logger.debug(f"No resolver found for type {tp} and data {data_name}")

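Not part of the diff: the class above is the extension point, so any class decorated with DynamicDimTypeResolver.register_resolver that implements type_match plus extract_inner_types (or resolve) joins the dispatch in generic_resolve. As a hedged illustration, a hypothetical resolver for dict containers could follow the same pattern as the list resolver below, reusing only names the module already imports:

@DynamicDimTypeResolver.register_resolver
class DictDynamicDimTypeResolver(DynamicDimTypeResolver):
    # Hypothetical sketch, not in this PR: recurse into dict values using the
    # declared value type from a dict[K, V] annotation.
    def type_match(self, tp) -> bool:
        return get_origin(tp) is dict

    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if not data:
            return []
        args = typing.get_args(tp)
        value_type = args[1] if args else Any
        # bind the key via partial so each accessor captures its own key
        return [(partial(lambda k, d: d[k], key), f"{data_name}[{key!r}]", value_type) for key in data]
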
@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return dataclasses.is_dataclass(tp) and isinstance(tp, type)

    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        type_hints = get_type_hints(tp, include_extras=True)
        return [  # type: ignore
            (
                # bind the field name via partial to avoid capturing the wrong free variables
                partial(lambda name, dt: getattr(dt, name), field.name),
                f"{data_name}.{field.name}",
                type_hints[field.name],
            )
            for field in dataclasses.fields(tp)
        ]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is Union and len(tp.__args__) == 2 and tp.__args__[1] is type(None)  # noqa: E721

    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if data is None:
            return []
        inner_type = tp.__args__[0]
        return [(lambda x: x, data_name, inner_type)]  # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is list

    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if not data:
            return []
        inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
        return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))]  # type: ignore


@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
    INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"

    def type_match(self, tp) -> bool:
        return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

    def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
        if isinstance(fields, str):
            raise TypeError(
                f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}"
            )
        # include_extras keeps Annotated[...] metadata such as DynamicDims
        inner_types_dict = typing.get_type_hints(tp, include_extras=True)
        return [
            (partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type)
            for field_name, inner_type in inner_types_dict.items()
        ]

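Not part of the diff: to illustrate the contract of ManualMarkedInnerFieldsDynamicDimTypeResolver, a plain (non-dataclass) container opts in by exposing __infer_dynamic_dims_fields__ as a tuple, after which its type-annotated attributes are traversed like dataclass fields. The class below is hypothetical:

class HypotheticalForwardMeta:
    # Hypothetical example class; the attribute below opts it into
    # ManualMarkedInnerFieldsDynamicDimTypeResolver.
    __infer_dynamic_dims_fields__ = ("input_ids", "caches")

    input_ids: Annotated[Tensor, DynamicDims((0, 1))]  # batch and sequence axes marked dynamic
    caches: list[Tensor]  # each cache tensor gets the implicit dynamic axis 0

    def __init__(self, input_ids: Tensor, caches: list[Tensor]):
        self.input_ids = input_ids
        self.caches = caches
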

@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

    def resolve(self, data, data_name, tp) -> None:
        base_type, *metadata = typing.get_args(tp)
        # Keep only the DynamicDims instances from the annotation metadata
        dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
        if not dynamic_dims:
            return
        if len(dynamic_dims) > 1:
            raise ValueError("Multiple DynamicDims annotations found. Only one is allowed.")
        dynamic_dims = dynamic_dims[0].dims
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


@DynamicDimTypeResolver.register_resolver
class TensorImplicitFirstDimOnlyDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return tp is Tensor

    def resolve(self, data, data_name, tp) -> None:
        # A bare Tensor annotation implies dynamic_dims=(0,)
        dynamic_dims = (0,)
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


def resolve_dynamic_dims(arg, arg_name, annotation):
    DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)
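Taken together, the intended flow is to annotate the fields of an input structure and hand it to resolve_dynamic_dims, which walks the annotations and marks the corresponding tensor axes as dynamic. A hedged, self-contained sketch (the dataclass, shapes, and import path are illustrative, not from this PR):

import dataclasses
from typing import Annotated, Optional

import paddle
from paddle import Tensor

# Hypothetical import path; adjust to wherever this module lives in the repo.
from dynamic_dims_marker import DynamicDims, resolve_dynamic_dims


@dataclasses.dataclass
class BatchInputs:
    # Annotated[Tensor, DynamicDims(...)]: axes 0 (batch) and 1 (sequence) are dynamic
    input_ids: Annotated[Tensor, DynamicDims((0, 1))]
    # Bare Tensor annotation: axis 0 is implicitly dynamic
    rope_emb: Tensor
    # Optional and list containers are unwrapped recursively
    caches: Optional[list[Annotated[Tensor, DynamicDims(1)]]] = None


inputs = BatchInputs(
    input_ids=paddle.zeros([4, 128], dtype="int64"),
    rope_emb=paddle.zeros([4, 64], dtype="float32"),
    caches=[paddle.zeros([4, 16, 8], dtype="float32")],
)

# Recursively resolves the annotations and calls paddle.jit.marker.dynamic_dims on each tensor.
resolve_dynamic_dims(inputs, "inputs", BatchInputs)
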
@@ -13,6 +13,7 @@
# limitations under the License.

from .append_attn_backend import AppendAttentionBackend
from .attention import Attention
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .block_multihead_attn_backend import BlockAttentionBackend

@@ -32,4 +33,5 @@
    "FlashAttentionBackend",
    "IluvatarAttnBackend",
    "BlockAttentionBackend",
    "Attention",
Review comment (translated): I recall that adding this causes a circular import? Reply: It will not, unless an earlier design flaw is responsible.
]
Review comment (translated): Is the default here "no"? Reply: In the framework the default is to specialize dims to 1, i.e. "1"; in FD the default is no specialization, i.e. "no".