|
1 | 1 | import warnings
|
2 | 2 | from pathlib import Path
|
3 |
| -from typing import Any, Iterable, Literal, TypeVar, cast |
| 3 | +from typing import Any, Iterable, TypeVar, cast |
4 | 4 |
|
5 | 5 | import torch
|
6 | 6 | from jaxtyping import Float
|
7 | 7 | from numpy import array, float32
|
8 | 8 | from PIL import Image
|
9 |
| -from safetensors import safe_open as _safe_open # type: ignore |
10 |
| -from safetensors.torch import save_file as _save_file # type: ignore |
| 9 | +from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore |
11 | 10 | from torch import Tensor, device as Device, dtype as DType
|
12 | 11 | from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
13 | 12 |
|
@@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
|
186 | 185 | return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
|
187 | 186 |
|
188 | 187 |
|
189 |
| -def safe_open( |
190 |
| - path: Path | str, |
191 |
| - framework: Literal["pytorch", "tensorflow", "flax", "numpy"], |
192 |
| - device: Device | str = "cpu", |
193 |
| -) -> dict[str, Tensor]: |
194 |
| - """Open a SafeTensor file from disk. |
195 |
| -
|
196 |
| - Args: |
197 |
| - path: The path to the file. |
198 |
| - framework: The framework used to save the file. |
199 |
| - device: The device to use for the tensors. |
200 |
| -
|
201 |
| - Returns: |
202 |
| - The loaded tensors. |
203 |
| - """ |
204 |
| - framework_mapping = { |
205 |
| - "pytorch": "pt", |
206 |
| - "tensorflow": "tf", |
207 |
| - "flax": "flax", |
208 |
| - "numpy": "numpy", |
209 |
| - } |
210 |
| - return _safe_open( |
211 |
| - str(path), |
212 |
| - framework=framework_mapping[framework], |
213 |
| - device=str(device), |
214 |
| - ) # type: ignore |
215 |
| - |
216 |
| - |
217 | 188 | def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
|
218 | 189 | """Load tensors from a file saved with `torch.save` from disk.
|
219 | 190 |
|
@@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic
|
247 | 218 | Returns:
|
248 | 219 | The loaded tensors.
|
249 | 220 | """
|
250 |
| - with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore |
251 |
| - return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore |
| 221 | + return _load_file(path, str(device)) |
252 | 222 |
|
253 | 223 |
|
254 | 224 | def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
|
0 commit comments