Commit 814406d
feat(mm): siglip model loading supports partial loading
In the previous commit, the LLaVA model was updated to support partial loading. In this commit, the SigLIP model is updated in the same way. This model is used for FLUX Redux. It's <4GB and only ever run in isolation, so it won't benefit from partial loading for the vast majority of users. Regardless, I think it is best if we make _all_ models work with partial loading.

PS: I also fixed the initial load dtype issue, described in the prev commit. It's probably a non-issue for this model, but we may as well fix it.
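For readers without the previous commit in front of them, the dtype fix amounts to materializing the weights in the working dtype at load time rather than loading them and converting afterwards. A minimal before/after sketch with an illustrative path and dtype (the real loader passes its own self._torch_dtype):

import torch
from transformers import SiglipVisionModel

model_path = "path/to/siglip"  # illustrative; the real path comes from the model config

# Before: load in the default dtype, then convert in place.
# model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True)
# model.to(dtype=torch.bfloat16)

# After: load directly in the working dtype.
model = SiglipVisionModel.from_pretrained(
    model_path,
    local_files_only=True,
    torch_dtype=torch.bfloat16,  # illustrative; the loader uses self._torch_dtype
)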
1 parent c054501 commit 814406d

4 files changed (+13, -33 lines)

invokeai/app/invocations/flux_redux.py

Lines changed: 9 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 import torch
 from PIL import Image
+from transformers import SiglipImageProcessor, SiglipVisionModel
 
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -115,8 +116,14 @@ def _downsample_weight(self, context: InvocationContext, redux_conditioning: tor
     @torch.no_grad()
     def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
         siglip_model_config = self._get_siglip_model(context)
-        with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
-            assert isinstance(siglip_pipeline, SigLipPipeline)
+        with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
+            assert isinstance(model, SiglipVisionModel)
+
+            model_abs_path = context.models.get_absolute_path(siglip_model_config)
+            processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
+            assert isinstance(processor, SiglipImageProcessor)
+
+            siglip_pipeline = SigLipPipeline(processor, model)
             return siglip_pipeline.encode_image(
                 x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
             )

invokeai/backend/model_manager/load/model_loaders/sig_lip_pipeline.py renamed to invokeai/backend/model_manager/load/model_loaders/sig_lip.py

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,14 @@
 from pathlib import Path
 from typing import Optional
 
+from transformers import SiglipVisionModel
+
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.SigLIP, format=ModelFormat.Diffusers)
@@ -23,6 +24,5 @@ def _load_model(
             raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
 
         model_path = Path(config.path)
-        model = SigLipPipeline.load_from_path(model_path)
-        model.to(dtype=self._torch_dtype)
+        model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype=self._torch_dtype)
         return model

invokeai/backend/model_manager/load/model_util.py

Lines changed: 0 additions & 4 deletions
@@ -16,11 +16,9 @@
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
 from invokeai.backend.model_manager.taxonomy import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -51,8 +49,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
             GroundingDinoPipeline,
             SegmentAnythingPipeline,
             DepthAnythingPipeline,
-            SigLipPipeline,
-            LlavaOnevisionModel,
         ),
     ):
         return model.calc_size()
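With SigLipPipeline and LlavaOnevisionModel removed from the calc_size() special cases, their memory footprint is now computed through the generic nn.Module path. As a rough illustration of what that generic path has to account for (an approximation for explanation, not the project's calc_module_size):

import torch

def approx_module_size(module: torch.nn.Module) -> int:
    # Sum the bytes held by parameters and buffers; roughly what a
    # generic module-size calculation measures.
    total = 0
    for p in module.parameters():
        total += p.numel() * p.element_size()
    for b in module.buffers():
        total += b.numel() * b.element_size()
    return total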
invokeai/backend/sig_lip/sig_lip_pipeline.py

Lines changed: 1 addition & 24 deletions
@@ -1,14 +1,9 @@
-from pathlib import Path
-from typing import Optional
-
 import torch
 from PIL import Image
 from transformers import SiglipImageProcessor, SiglipVisionModel
 
-from invokeai.backend.raw_model import RawModel
-
 
-class SigLipPipeline(RawModel):
+class SigLipPipeline:
     """A wrapper for a SigLIP model + processor."""
 
     def __init__(
@@ -19,25 +14,7 @@ def __init__(
         self._siglip_processor = siglip_processor
         self._siglip_model = siglip_model
 
-    @classmethod
-    def load_from_path(cls, path: str | Path):
-        siglip_model = SiglipVisionModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(siglip_model, SiglipVisionModel)
-        siglip_processor = SiglipImageProcessor.from_pretrained(path, local_files_only=True)
-        assert isinstance(siglip_processor, SiglipImageProcessor)
-        return cls(siglip_processor, siglip_model)
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        self._siglip_model.to(device=device, dtype=dtype)
-
     def encode_image(self, x: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
         imgs = self._siglip_processor.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
         encoded_x = self._siglip_model(**imgs.to(device=device, dtype=dtype)).last_hidden_state
         return encoded_x
-
-    def calc_size(self) -> int:
-        """Get size of the model in memory in bytes."""
-        # HACK(ryand): Fix this issue with circular imports.
-        from invokeai.backend.model_manager.load.model_util import calc_module_size
-
-        return calc_module_size(self._siglip_model)
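After this change, SigLipPipeline is a thin wrapper assembled by the caller rather than a self-loading RawModel. A small standalone usage sketch, assuming a local SigLIP model directory is available (path, device, and dtype are illustrative; inside InvokeAI the model cache handles device placement):

import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel

from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline

model_path = "path/to/siglip"  # illustrative local model directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # illustrative
dtype = torch.bfloat16  # illustrative working dtype

model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype=dtype)
model.to(device)  # in InvokeAI this placement is handled by the model cache
processor = SiglipImageProcessor.from_pretrained(model_path, local_files_only=True)

pipeline = SigLipPipeline(processor, model)
embedding = pipeline.encode_image(
    x=Image.open("image.png"),  # illustrative input image
    device=device,
    dtype=dtype,
)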
