Skip to content

Commit dabc13f

Browse files
authored
[LLM] support disable monkey patch (#10617)
1 parent 2c02199 commit dabc13f

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

paddlenlp/utils/paddle_patch.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
1517
import ml_dtypes
1618
import numpy as np
1719
import paddle
1820
import paddle.nn as nn
1921

22+
enable_paddlenlp_monkey_patch = int(os.getenv("FLAGS_enable_paddlenlp_monkey_patch", "1"))
23+
2024
np.bfloat16 = ml_dtypes.bfloat16
2125
np.float8_e5m2 = ml_dtypes.float8_e5m2
2226
np.float8_e4m3fn = ml_dtypes.float8_e4m3fn
@@ -141,8 +145,9 @@ def _numel(self, *args, **kwargs):
141145
return origin_numel(self, *args, **kwargs)
142146

143147

144-
paddle.Tensor.numpy = _numpy
145-
paddle.Tensor.__call__ = enhance_init
146-
paddle.to_tensor = enhance_to_tensor
147-
paddle.core.eager.Tensor.set_value = enhance_set_value
148-
paddle.Tensor.numel = _numel
148+
if enable_paddlenlp_monkey_patch:
149+
paddle.Tensor.numpy = _numpy
150+
paddle.Tensor.__call__ = enhance_init
151+
paddle.to_tensor = enhance_to_tensor
152+
paddle.core.eager.Tensor.set_value = enhance_set_value
153+
paddle.Tensor.numel = _numel

0 commit comments

Comments
 (0)