Skip to content

Commit 0d4566d

Browse files
authored
[SOT] Add internal API paddle.jit.marker.force_dynamic to ensure a function or layer runs under dynamic mode (#73059)
1 parent 71215b5 commit 0d4566d

File tree

7 files changed

+163
-15
lines changed

7 files changed

+163
-15
lines changed

python/paddle/jit/marker.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import inspect
1718
from typing import (
1819
Callable,
1920
Protocol,
@@ -25,6 +26,8 @@
2526
ParamSpec,
2627
)
2728

29+
import paddle
30+
2831
from .dy2static.utils import (
2932
TransformOptions,
3033
)
@@ -123,3 +126,24 @@ def _mark_as_unified(fn, *, for_sot: bool, for_ast: bool):
123126
if fn is None:
124127
return lambda fn: _mark_as_unified(fn, for_sot=for_sot, for_ast=for_ast)
125128
return _mark_as_unified(fn, for_sot=for_sot, for_ast=for_ast)
129+
130+
131+
def force_dynamic(
    fn: Callable[_InputT, _RetT] | type[paddle.nn.Layer] | None = None,
) -> Callable[_InputT, _RetT]:
    """
    Mark a function or ``paddle.nn.Layer`` subclass to always execute in
    dynamic (eager) mode. When SOT encounters the marked object it breaks
    the graph, preventing the call from being converted to static mode.

    Args:
        fn: The function or ``paddle.nn.Layer`` subclass to mark, or
            ``None`` to obtain the marker for decorator-style use.

    Returns:
        The input object unchanged (so this works as a decorator), or the
        marker itself when ``fn`` is ``None``.

    Raises:
        TypeError: If ``fn`` is neither a plain function nor a
            ``paddle.nn.Layer`` subclass.
    """
    # Imported lazily to avoid a circular import between paddle.jit.marker
    # and paddle.jit.sot.
    from paddle.jit import sot

    # The signature admits None (default), mirroring _mark_as_unified's
    # decorator-factory behavior: force_dynamic() returns the marker
    # instead of raising TypeError on NoneType.
    if fn is None:
        return force_dynamic

    if inspect.isclass(fn) and issubclass(fn, paddle.nn.Layer):
        sot.utils.paddle_api_config.add_break_graph_layer_class(fn)
        return fn
    if inspect.isfunction(fn):
        sot.utils.paddle_api_config.add_break_graph_function(fn)
        return fn

    raise TypeError(
        f"Expected a callable or paddle.nn.Layer, but got {type(fn).__name__}."
    )

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
get_obj_stable_repr,
4545
get_static_function,
4646
hashable,
47-
is_break_graph_api,
4847
is_break_graph_tensor_methods,
4948
is_builtin_fn,
5049
is_directly_run_api,
@@ -62,6 +61,7 @@
6261
DataDependencyOperationBreak,
6362
FallbackError,
6463
FallbackInlineCallBreak,
64+
ForceBreak,
6565
InnerError,
6666
OtherInlineCallBreak,
6767
PsdbBreakReason,
@@ -70,9 +70,12 @@
7070
SotErrorBase,
7171
UnsupportedNumPyAPIBreak,
7272
UnsupportedOperationBreak,
73-
UnsupportedPaddleAPIBreak,
7473
UnsupportedRandomAPIBreak,
7574
)
75+
from ....utils.paddle_api_config import (
76+
break_graph_functions,
77+
break_graph_layer_classes,
78+
)
7679
from ..dispatcher import Dispatcher
7780
from ..guard import (
7881
FasterStringifiedExpression,
@@ -146,6 +149,33 @@ def call_function(self, /, *args, **kwargs):
146149
raise NotImplementedError("call_function is not implemented.")
147150

148151

152+
class ForceBreakCallableVariable(CallableVariable):
    """Callable wrapper whose invocation always triggers a graph break.

    Produced for objects registered via ``paddle.jit.marker.force_dynamic``
    so that they are executed in dynamic mode rather than being captured
    into the static graph.
    """

    def __init__(self, name: str, graph: FunctionGraph, tracker: Tracker):
        super().__init__(graph, tracker)
        # Human-readable label embedded in the break-graph reason.
        self.name = name

    def call_function(self, /, *args, **kwargs) -> VariableBase:
        # Never simulate the call; force the executor to fall back.
        raise BreakGraphError(ForceBreak(reason_str=f"Force run {self.name}"))

    def get_py_value(self, allow_tensor=False):
        return self.value

    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
        # Layers marked via force_dynamic are matched by their class.
        is_marked_layer = (
            isinstance(value, paddle.nn.Layer)
            and value.__class__ in break_graph_layer_classes
        )
        if is_marked_layer:
            label = f"Layer({value.__class__.__name__})"
            return ForceBreakCallableVariable(label, graph, tracker)
        # Plain callables are matched by identity in the function registry.
        if hashable(value) and value in break_graph_functions:
            label = get_obj_stable_repr(value)
            return ForceBreakCallableVariable(label, graph, tracker)
        return None
177+
178+
149179
class FunctionVariable(CallableVariable):
150180
"""
151181
FunctionVariable is a subclass of CallableVariable used to wrap a function variable.
@@ -364,10 +394,6 @@ def __init__(
364394
super().__init__(fn, graph, tracker)
365395

366396
def call_function(self, /, *args, **kwargs):
    # Simulate the Paddle API call inside the captured graph. Break-graph
    # handling for marked APIs is done by ForceBreakCallableVariable at
    # variable-construction time, so no per-call check is needed here.
    return self.graph.call_paddle_api(self.value, *args, **kwargs)
372398

373399
@VariableFactory.register_from_value(

python/paddle/jit/sot/utils/exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ def __init__(
144144
)
145145

146146

147+
class ForceBreak(UnsupportedOperationBreak):
    """Break reason for callables explicitly marked to run in dynamic mode
    (see ``paddle.jit.marker.force_dynamic``)."""

    def __init__(
        self,
        *,
        reason_str=None,
        file_path="",
        line_number=-1,
    ):
        # Fall back to a generic reason when the caller supplies none.
        super().__init__(
            reason_str=(
                "Force break graph execution"
                if reason_str is None
                else reason_str
            ),
            file_path=file_path,
            line_number=line_number,
        )
163+
164+
147165
class BuiltinFunctionBreak(UnsupportedOperationBreak):
148166
"""Break reason for unsupported built-in function calls.
149167

python/paddle/jit/sot/utils/paddle_api_config.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,8 @@ def _get_tensor_methods():
116116
"paddle.nn.functional",
117117
}
118118

119-
break_graph_set = set()
120-
121-
119+
break_graph_functions = set()
120+
break_graph_layer_classes = set()
122121
break_graph_tensor_method = {
123122
'register_hook',
124123
'numpy',
@@ -139,8 +138,12 @@ def is_break_graph_tensor_methods(method_name):
139138
return method_name in break_graph_tensor_method
140139

141140

142-
def add_break_graph_apis(apis: list):
143-
break_graph_set.update(apis)
141+
def add_break_graph_function(fn):
    """Register *fn* so that calling it always breaks the graph."""
    break_graph_functions.add(fn)
143+
144+
145+
def add_break_graph_layer_class(layer_class: type[paddle.nn.Layer]):
    """Register a Layer subclass whose instances always break the graph."""
    break_graph_layer_classes.add(layer_class)
144147

145148

146149
def is_directly_run_api(api):

python/paddle/jit/sot/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
ENV_STRICT_MODE,
3939
)
4040
from .paddle_api_config import (
41-
break_graph_set,
41+
break_graph_functions,
4242
paddle_api_list,
4343
paddle_api_module_prefix,
4444
)
@@ -237,7 +237,7 @@ def in_paddle_module(func):
237237

238238

239239
def is_break_graph_api(func):
    """Return whether *func* was registered as a break-graph function."""
    return func in break_graph_functions
241241

242242

243243
def is_namedtuple_class(cls):

test/sot/test_break_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from test_case_base import TestCaseBase
1919

2020
import paddle
21-
from paddle.jit.sot.utils.paddle_api_config import add_break_graph_apis
21+
from paddle.jit.sot.utils.paddle_api_config import add_break_graph_function
2222

2323

2424
def ifelse_func(x, y):
@@ -74,7 +74,7 @@ def to_tensor_break_graph(x, y):
7474

7575
class TestToTensor(TestCaseBase):
7676
def test_simple(self):
77-
add_break_graph_apis([paddle.to_tensor])
77+
add_break_graph_function(paddle.to_tensor)
7878
x = paddle.to_tensor(2)
7979
y = paddle.to_tensor(3)
8080
self.assert_results(to_tensor_break_graph, x, y)

test/sot/test_force_dynamic.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from test_case_base import (
18+
TestCaseBase,
19+
test_instruction_translator_cache_context,
20+
)
21+
22+
import paddle
23+
from paddle.jit import sot
24+
25+
26+
class EmbeddingLayer(paddle.nn.Layer):
    """Embedding sandwiched between elementwise ops, so that a graph break
    inside the embedding splits the captured graph into several pieces."""

    def __init__(self):
        super().__init__()
        self.embedding = paddle.nn.Embedding(10, 10)

    def forward(self, x):
        out = x + 1 - 1
        out = self.embedding(out)
        return out + 1
35+
36+
37+
def call_embedding_layer(x: paddle.Tensor, layer: paddle.nn.Layer):
    """Invoke *layer* on *x*; serves as the SOT entry point under test."""
    return layer(x)
39+
40+
41+
def call_functional_embedding(x: paddle.Tensor, weight: paddle.Tensor):
    """Mirror EmbeddingLayer.forward using the functional embedding API."""
    out = x + 1 - 1
    out = paddle.nn.functional.embedding(out, weight)
    return out + 1
45+
46+
47+
class TestForceDynamic(TestCaseBase):
    """Verify that force_dynamic-marked layers and functions break the graph.

    A translate_count greater than 1 indicates the traced function was split
    into multiple sub-graphs, i.e. a graph break happened at the marked call.
    """

    def test_embedding_layer(self):
        paddle.jit.marker.force_dynamic(paddle.nn.Embedding)
        # try/finally (not a trailing clear) so the global registry is
        # restored even when an assertion fails; otherwise the marked class
        # would leak into every subsequent test in the process.
        try:
            layer = EmbeddingLayer()
            with test_instruction_translator_cache_context() as ctx:
                self.assertEqual(ctx.translate_count, 0)
                self.assert_results(
                    call_embedding_layer,
                    paddle.randint(0, 10, [1, 3, 224, 224], dtype='int64'),
                    layer,
                )
                self.assertGreater(ctx.translate_count, 1)
        finally:
            sot.utils.paddle_api_config.break_graph_layer_classes.clear()

    def test_functional_embedding(self):
        paddle.jit.marker.force_dynamic(paddle.nn.functional.embedding)
        try:
            weight = paddle.randn([10, 10])
            x = paddle.randint(0, 10, [1, 3, 224, 224], dtype='int64')
            with test_instruction_translator_cache_context() as ctx:
                self.assertEqual(ctx.translate_count, 0)
                self.assert_results(call_functional_embedding, x, weight)
                self.assertGreater(ctx.translate_count, 1)
        finally:
            # Restore global state so other tests see a clean registry.
            sot.utils.paddle_api_config.break_graph_functions.clear()
74+
75+
76+
if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()

0 commit comments

Comments
 (0)