Commit c054501

feat(mm): llava model loading supports partial loading; fix OOM crash on initial load
The model manager has two types of model cache entries:

- `CachedModelOnlyFullLoad`: The model may only ever be loaded and unloaded as a single object.
- `CachedModelWithPartialLoad`: The model may be partially loaded and unloaded. Partial loading is enabled by overwriting certain torch layer classes, adding the ability to autocast the layer to a device on the fly. See `CustomLinear` for an example.

So, to take advantage of partial loading and be cached as a `CachedModelWithPartialLoad`, the model must inherit from `torch.nn.Module`.

The LLaVA classes provided by `transformers` do inherit from `torch.nn.Module`, but we wrap those classes in a separate class called `LlavaOnevisionModel`. The wrapper encapsulates both the LLaVA model and its "processor" - a lightweight class that prepares model inputs like text and images. While it is more elegant to encapsulate both model and processor in a single entity, doing so prevents the model cache from enabling partial loading for the chunky vLLM model.

Fixing this involved a few changes:

- Update the `LlavaOnevisionModelLoader` class to operate on the vLLM model directly, instead of the `LlavaOnevisionModel` wrapper class.
- Instantiate the processor directly in the node. The processor is lightweight and does its work on the CPU, so we don't need to worry about caching it in the model manager.
- Remove the caching support code from the `LlavaOnevisionModel` wrapper class. It's not needed, because we no longer cache this class. The class now only handles running the models provided to it.
- Rename `LlavaOnevisionModel` to `LlavaOnevisionPipeline` to better represent its purpose.

These changes have a bonus effect of fixing an OOM crash when initially loading the models. This was most apparent when loading LLaVA 7B, which is pretty chunky.

The initial load is onto CPU RAM. In the old version of the loaders, we ignored the loader's target dtype for the initial load and instead loaded the model at `transformers`'s default dtype of fp32. LLaVA 7B is fp16 and weighs ~17GB, so loading it as fp32 requires double that amount (~34GB) of CPU RAM. Many users only have 32GB of RAM, so this caused a _CPU_ OOM - a hard crash of the whole process. With the updated loaders, the initial load now uses the target dtype, so LLaVA needs only the expected ~17GB of RAM for its initial load.

PS: If we hadn't made the accompanying partial loading changes, we still could have solved this OOM - we'd just need to pass the initial load dtype to the wrapper class and have it load at that dtype. But we may as well fix both issues.

PPS: There are other models whose model classes are wrappers around a torch module class and thus cannot be partially loaded. However, those models are typically fairly small and/or run only on their own, so they don't benefit as much from partial loading. It's the really big models (like LLaVA 7B) that benefit most from partial loading.
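For readers unfamiliar with the autocast mechanism described above, here is an illustrative sketch (not InvokeAI's actual `CustomLinear` implementation; the class name is made up) of how a torch layer can cast its weights to the compute device on the fly, which is what makes partial loading possible:

```python
import torch


class AutocastLinear(torch.nn.Linear):
    """Illustrative only: a linear layer whose weights may live on the CPU while
    the activations live on the GPU. The weights are copied to the input's device
    at call time, so only the layers actually in use need to occupy VRAM."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(x.device, non_blocking=True)
        bias = self.bias.to(x.device, non_blocking=True) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
```

By swapping a module's `torch.nn.Linear` layers for something like this, the cache can keep part of a large model in CPU RAM and still run it, streaming weights to the GPU layer by layer as needed.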
1 parent c1d819c commit c054501

File tree

3 files changed: +21 -31 lines

- invokeai/app/invocations/llava_onevision_vllm.py
- invokeai/backend/llava_onevision_model.py → invokeai/backend/llava_onevision_pipeline.py
- invokeai/backend/model_manager/load/model_loaders/llava_onevision.py

invokeai/app/invocations/llava_onevision_vllm.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -3,13 +3,14 @@
 import torch
 from PIL.Image import Image
 from pydantic import field_validator
+from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import StringOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
+from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
 from invokeai.backend.util.devices import TorchDevice
 
 
@@ -54,10 +55,17 @@ def _get_images(self, context: InvocationContext) -> list[Image]:
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> StringOutput:
         images = self._get_images(context)
+        model_config = context.models.get_config(self.vllm_model)
 
-        with context.models.load(self.vllm_model) as vllm_model:
-            assert isinstance(vllm_model, LlavaOnevisionModel)
-            output = vllm_model.run(
+        with context.models.load(self.vllm_model).model_on_device() as (_, model):
+            assert isinstance(model, LlavaOnevisionForConditionalGeneration)
+
+            model_abs_path = context.models.get_absolute_path(model_config)
+            processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
+            assert isinstance(processor, LlavaOnevisionProcessor)
+
+            model = LlavaOnevisionPipeline(model, processor)
+            output = model.run(
                 prompt=self.prompt,
                 images=images,
                 device=TorchDevice.choose_torch_device(),
```
invokeai/backend/llava_onevision_model.py → invokeai/backend/llava_onevision_pipeline.py

Lines changed: 3 additions & 24 deletions
```diff
@@ -1,26 +1,15 @@
-from pathlib import Path
-from typing import Optional
-
 import torch
 from PIL.Image import Image
-from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
+from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
 
-from invokeai.backend.raw_model import RawModel
 
+class LlavaOnevisionPipeline:
+    """A wrapper for a LLaVA Onevision model + processor."""
 
-class LlavaOnevisionModel(RawModel):
     def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
         self._vllm_model = vllm_model
         self._processor = processor
 
-    @classmethod
-    def load_from_path(cls, path: str | Path):
-        vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
-        assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
-        processor = AutoProcessor.from_pretrained(path, local_files_only=True)
-        assert isinstance(processor, LlavaOnevisionProcessor)
-        return cls(vllm_model, processor)
-
     def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
         # TODO(ryand): Tune the max number of images that are useful for the model.
         if len(images) > 3:
@@ -44,13 +33,3 @@ def run(self, prompt: str, images: list[Image], device: torch.device, dtype: tor
         # The output_str will include the prompt, so we extract the response.
         response = output_str.split("assistant\n", 1)[1].strip()
         return response
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        self._vllm_model.to(device=device, dtype=dtype)
-
-    def calc_size(self) -> int:
-        """Get size of the model in memory in bytes."""
-        # HACK(ryand): Fix this issue with circular imports.
-        from invokeai.backend.model_manager.load.model_util import calc_module_size
-
-        return calc_module_size(self._vllm_model)
```
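The hunk above shows only the tail of `run()`. For orientation, here is a minimal sketch of how a model + processor pair is typically driven with the `transformers` API; the function name, generation settings, and prompt construction below are assumptions, and the real `LlavaOnevisionPipeline.run` may differ in those details:

```python
import torch
from PIL.Image import Image
from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor


def run_vllm(
    model: LlavaOnevisionForConditionalGeneration,
    processor: LlavaOnevisionProcessor,
    prompt: str,
    images: list[Image],
    device: torch.device,
    dtype: torch.dtype,
) -> str:
    # Build a single-turn chat: the images followed by the text prompt.
    conversation = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}],
        }
    ]
    templated = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # The processor tokenizes the text and preprocesses the images.
    inputs = processor(images=images, text=templated, return_tensors="pt").to(device=device, dtype=dtype)

    output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    output_str = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

    # The decoded string includes the prompt, so keep only the assistant's reply,
    # mirroring the split shown in the diff above.
    return output_str.split("assistant\n", 1)[1].strip()
```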

invokeai/backend/model_manager/load/model_loaders/llava_onevision.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -1,7 +1,8 @@
 from pathlib import Path
 from typing import Optional
 
-from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
+from transformers import LlavaOnevisionForConditionalGeneration
+
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
 )
@@ -23,6 +24,8 @@ def _load_model(
             raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
 
         model_path = Path(config.path)
-        model = LlavaOnevisionModel.load_from_path(model_path)
-        model.to(dtype=self._torch_dtype)
+        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
+            model_path, local_files_only=True, torch_dtype=self._torch_dtype
+        )
+        assert isinstance(model, LlavaOnevisionForConditionalGeneration)
         return model
```
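The loader change above, passing `torch_dtype=self._torch_dtype` to `from_pretrained`, is what implements the dtype fix described in the commit message. As a rough sketch of the memory arithmetic (the ~8.5B total parameter count below is an assumption used only to make the numbers concrete; the ~17GB and ~34GB figures come from the commit message):

```python
# Back-of-the-envelope CPU RAM needed for the initial load of LLaVA OneVision 7B.
# The ~8.5e9 total parameter count (language model + vision tower + projector) is
# an assumption; only the ~17 GB / ~34 GB figures are stated in the commit message.
params = 8.5e9
print(f"fp16 initial load: ~{params * 2 / 1e9:.0f} GB")  # ~17 GB, the checkpoint's native size
print(f"fp32 initial load: ~{params * 4 / 1e9:.0f} GB")  # ~34 GB, enough to OOM a 32 GB machine
```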
