Skip to content

Commit 8c1e5bc

Browse files
author
Laurent
committed
remove compute_clip_image_embedding overloads in image_prompt.py + improve docstring slightly
1 parent 1fc2ad3 commit 8c1e5bc

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

src/refiners/foundationals/latent_diffusion/image_prompt.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
2+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
33

44
import torch
55
from jaxtyping import Float
@@ -454,37 +454,28 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
454454
"""
455455
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
456456

457-
@overload
458-
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ...
459-
460-
@overload
461-
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ...
462-
463-
@overload
464-
def compute_clip_image_embedding(
465-
self, image_prompt: list[Image.Image], weights: list[float] | None = None
466-
) -> Tensor: ...
467-
468457
def compute_clip_image_embedding(
469458
self,
470459
image_prompt: Tensor | Image.Image | list[Image.Image],
471460
weights: list[float] | None = None,
472461
concat_batches: bool = True,
473462
) -> Tensor:
474-
"""Compute the CLIP image embedding.
463+
"""Compute CLIP image embedding(s).
475464
476465
Args:
477-
image_prompt: The image prompt to use.
478-
weights: The scale to use for the image prompt.
479-
concat_batches: Whether to concatenate the batches.
466+
image_prompt: The image prompt(s) to use.
467+
weights: The scale(s) to use for the image prompt(s).
468+
concat_batches: Whether to concatenate the image embeddings along the feature dimension.
480469
481470
Returns:
482-
The CLIP image embedding.
471+
The CLIP image embedding(s).
483472
"""
484473
if isinstance(image_prompt, Image.Image):
485474
image_prompt = self.preprocess_image(image_prompt)
486475
elif isinstance(image_prompt, list):
487-
assert all(isinstance(image, Image.Image) for image in image_prompt)
476+
assert all(
477+
isinstance(image, Image.Image) for image in image_prompt
478+
), "All elements of `image_prompt` must be of type `Image.Image`"
488479
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
489480

490481
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)

0 commit comments

Comments
 (0)