We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 91920a6 commit 2b1efb2Copy full SHA for 2b1efb2
paddlenlp/utils/serialization.py
@@ -25,11 +25,11 @@
25
from _io import BufferedReader
26
from safetensors import deserialize
27
28
-from paddlenlp.transformers.utils import device_guard
29
from paddlenlp.utils.env import PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
30
31
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30
32
+
33
_TYPES = {
34
"F64": np.float64,
35
"F32": np.float32,
@@ -206,6 +206,8 @@ def dumpy(*args, **kwarsg):
206
207
208
def load_torch(path: str, **pickle_load_args):
209
+ from paddlenlp.transformers.utils import device_guard
210
211
if path.endswith(PYTORCH_WEIGHTS_NAME) or os.path.split(path)[-1].startswith("pytorch_model-"):
212
import torch
213
0 commit comments