Skip to content

Commit e7007d1

Browse files
authored
[SOT] Add partial support (#72956)
1 parent 441abda commit e7007d1

File tree

3 files changed

+121
-4
lines changed

3 files changed

+121
-4
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
NumPyApiVariable,
5151
PaddleApiVariable,
5252
PaddleLayerVariable,
53+
PartialVariable,
5354
UserCodeVariable,
5455
UserDefinedFunctionVariable,
5556
UserDefinedGeneratorFunctionVariable,

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import operator
2222
import sys
2323
import types
24-
from functools import reduce
24+
from functools import partial, reduce
2525
from typing import (
2626
TYPE_CHECKING,
2727
Any,
@@ -1118,7 +1118,9 @@ def call_function(self, /, *args, **kwargs):
11181118
# do not have init function
11191119
if self.value.__init__ is object.__init__:
11201120
return VariableFactory.from_value(
1121-
new_object, self.graph, DummyTracker([self])
1121+
new_object,
1122+
self.graph,
1123+
DummyTracker([self, *args, *kwargs.values()]),
11221124
)
11231125

11241126
if not hasattr(self.value.__init__, "__code__"):
@@ -1138,7 +1140,7 @@ def call_function(self, /, *args, **kwargs):
11381140
new_object_variable = VariableFactory.from_value(
11391141
new_object,
11401142
self.graph,
1141-
DummyTracker([self, *list(args), *list(kwargs.values())]),
1143+
DummyTracker([self, *args, *kwargs.values()]),
11421144
)
11431145
fn_var(new_object_variable, *args, **kwargs)
11441146
return new_object_variable
@@ -1234,3 +1236,48 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
12341236
if is_namedtuple_class(value):
12351237
return NamedTupleClassVariable(value, graph, tracker)
12361238
return None
1239+
1240+
1241+
class PartialVariable(CallableVariable):
1242+
def __init__(
1243+
self,
1244+
value: partial,
1245+
graph: FunctionGraph,
1246+
tracker: Tracker,
1247+
):
1248+
super().__init__(graph, tracker)
1249+
self.value = value
1250+
1251+
def get_py_value(self, allow_tensor=False):
1252+
return self.value
1253+
1254+
def get_py_type(self):
1255+
return partial
1256+
1257+
def call_function(self, /, *call_args, **call_kwargs):
1258+
func_variable = VariableFactory.from_value(
1259+
self.value.func, self.graph, GetAttrTracker(self, "func")
1260+
)
1261+
partial_args = VariableFactory.from_value(
1262+
self.value.args, self.graph, GetAttrTracker(self, "args")
1263+
)
1264+
partial_keywords = VariableFactory.from_value(
1265+
self.value.keywords, self.graph, GetAttrTracker(self, "keywords")
1266+
)
1267+
assert isinstance(func_variable, CallableVariable)
1268+
1269+
partial_keywords.get_wrapped_items().update(call_kwargs)
1270+
1271+
out = func_variable(
1272+
*partial_args.get_wrapped_items(),
1273+
*call_args,
1274+
**(partial_keywords.get_wrapped_items() | call_kwargs),
1275+
)
1276+
1277+
return out
1278+
1279+
@VariableFactory.register_from_value()
1280+
def from_value(value: partial, graph: FunctionGraph, tracker: Tracker):
1281+
if isinstance(value, partial):
1282+
return PartialVariable(value, graph, tracker)
1283+
return None

test/sot/test_functools.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import operator
1717
import unittest
1818

19-
from test_case_base import TestCaseBase
19+
from test_case_base import (
20+
TestCaseBase,
21+
)
2022

2123
import paddle
2224
from paddle.jit.sot.psdb import check_no_breakgraph
@@ -54,6 +56,54 @@ def try_reduce_iter(fn, var, init=None):
5456
return ans
5557

5658

59+
def add(a, b, c, d=2, e=3, f=4):
60+
return a, b, c, d, e, f
61+
62+
63+
@check_no_breakgraph
64+
def simple_partial(x=1):
65+
partial_func = functools.partial(add, x)
66+
out = partial_func(2, 3)
67+
return out
68+
69+
70+
@check_no_breakgraph
71+
def simple_partial_with_two_args(x=1):
72+
partial_func = functools.partial(add, x, 2)
73+
out = partial_func(3)
74+
return out
75+
76+
77+
@check_no_breakgraph
78+
def simple_partial_with_n_args(x=1):
79+
partial_func = functools.partial(add, x, 2, 3, 4, 5, 6)
80+
out = partial_func()
81+
return out
82+
83+
84+
@check_no_breakgraph
85+
def simple_partial_with_n_args_kwargs():
86+
partial_func = functools.partial(add, 1, 2, d=2)
87+
out = partial_func(paddle.to_tensor(3), e=7)
88+
return out
89+
90+
91+
@check_no_breakgraph
92+
def simple_partial_with_n_args_same_kwargs():
93+
partial_func = functools.partial(add, 1, 2, e=2)
94+
out = partial_func(3, e=7)
95+
return out
96+
97+
98+
# NOTE(DrRyanHuang): Currently, SOT does not support tensor bind yet
99+
# because this case is not common enough.
100+
# We could create a PartialClassVariable to prevent breakgraph.
101+
def simple_partial_with_tensor_bind():
102+
partial_func = functools.partial(add, 1, paddle.to_tensor(2.0), e=2)
103+
out = partial_func(3, e=7)
104+
return out
105+
106+
57107
@check_no_breakgraph
58108
def try_reduce_iter_failed(fn, var):
59109
it = iter(var)
@@ -97,6 +147,25 @@ def test_reduce_with_init_value(self):
97147
def test_reduce_with_builtin_fn(self):
98148
self.assert_results(try_reduce, operator.add, [2, 5, 8])
99149

150+
def test_simple_partial(self):
151+
self.assert_results(simple_partial, 1)
152+
self.assert_results(simple_partial, 2)
153+
154+
def test_simple_partial_with_two_argsn(self):
155+
self.assert_results(simple_partial_with_two_args)
156+
157+
def test_simple_partial_with_n_args(self):
158+
self.assert_results(simple_partial_with_n_args)
159+
160+
def test_simple_partial_with_n_args_kwargs(self):
161+
self.assert_results(simple_partial_with_n_args_kwargs)
162+
163+
def test_simple_partial_with_n_args_same_kwargs(self):
164+
self.assert_results(simple_partial_with_n_args_same_kwargs)
165+
166+
def test_simple_partial_with_tensor_bind(self):
167+
self.assert_results(simple_partial_with_tensor_bind)
168+
100169

101170
if __name__ == "__main__":
102171
unittest.main()

0 commit comments

Comments
 (0)