Skip to content

Commit 2b1efb2

Browse files
authored
Update serialization (#10584)
* Update serialization.py * fix
1 parent 91920a6 commit 2b1efb2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddlenlp/utils/serialization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
from _io import BufferedReader
2626
from safetensors import deserialize
2727

28-
from paddlenlp.transformers.utils import device_guard
2928
from paddlenlp.utils.env import PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
3029

3130
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30
3231

32+
3333
_TYPES = {
3434
"F64": np.float64,
3535
"F32": np.float32,
@@ -206,6 +206,8 @@ def dumpy(*args, **kwarsg):
206206

207207

208208
def load_torch(path: str, **pickle_load_args):
209+
from paddlenlp.transformers.utils import device_guard
210+
209211
if path.endswith(PYTORCH_WEIGHTS_NAME) or os.path.split(path)[-1].startswith("pytorch_model-"):
210212
import torch
211213

0 commit comments

Comments
 (0)