Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit f289548

Browse files
committed
simplify implementation of load_from_safetensors
1 parent 444882a commit f289548

File tree

3 files changed

+7
-35
lines changed

3 files changed

+7
-35
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
66
license = { text = "MIT License" }
77
dependencies = [
88
"torch>=2.1.1",
9-
"safetensors>=0.4.0",
9+
"safetensors>=0.4.5",
1010
"pillow>=10.4.0",
1111
"jaxtyping>=0.2.23",
1212
"packaging>=23.2",

requirements.lock

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ rpds-py==0.19.1
338338
# via referencing
339339
s3transfer==0.10.2
340340
# via boto3
341-
safetensors==0.4.3
341+
safetensors==0.4.5
342342
# via diffusers
343343
# via refiners
344344
# via timm
@@ -347,6 +347,8 @@ segment-anything-hq==0.3
347347
# via refiners
348348
segment-anything-py==1.0.1
349349
# via refiners
350+
sentencepiece==0.2.0
351+
# via refiners
350352
sentry-sdk==2.12.0
351353
# via wandb
352354
setproctitle==1.3.3

src/refiners/fluxion/utils.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import warnings
22
from pathlib import Path
3-
from typing import Any, Iterable, Literal, TypeVar, cast
3+
from typing import Any, Iterable, TypeVar, cast
44

55
import torch
66
from jaxtyping import Float
77
from numpy import array, float32
88
from PIL import Image
9-
from safetensors import safe_open as _safe_open # type: ignore
10-
from safetensors.torch import save_file as _save_file # type: ignore
9+
from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore
1110
from torch import Tensor, device as Device, dtype as DType
1211
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
1312

@@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
186185
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
187186

188187

189-
def safe_open(
190-
path: Path | str,
191-
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
192-
device: Device | str = "cpu",
193-
) -> dict[str, Tensor]:
194-
"""Open a SafeTensor file from disk.
195-
196-
Args:
197-
path: The path to the file.
198-
framework: The framework used to save the file.
199-
device: The device to use for the tensors.
200-
201-
Returns:
202-
The loaded tensors.
203-
"""
204-
framework_mapping = {
205-
"pytorch": "pt",
206-
"tensorflow": "tf",
207-
"flax": "flax",
208-
"numpy": "numpy",
209-
}
210-
return _safe_open(
211-
str(path),
212-
framework=framework_mapping[framework],
213-
device=str(device),
214-
) # type: ignore
215-
216-
217188
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
218189
"""Load tensors from a file saved with `torch.save` from disk.
219190
@@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic
247218
Returns:
248219
The loaded tensors.
249220
"""
250-
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
251-
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
221+
return _load_file(path, str(device))
252222

253223

254224
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:

0 commit comments

Comments
 (0)