|
1 | 1 | import math
|
2 |
| -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload |
| 2 | +from typing import TYPE_CHECKING, Any, Generic, TypeVar |
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from jaxtyping import Float
|
@@ -454,37 +454,28 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
454 | 454 | """
|
455 | 455 | self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
456 | 456 |
|
457 |
| - @overload |
458 |
| - def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ... |
459 |
| - |
460 |
| - @overload |
461 |
| - def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ... |
462 |
| - |
463 |
| - @overload |
464 |
| - def compute_clip_image_embedding( |
465 |
| - self, image_prompt: list[Image.Image], weights: list[float] | None = None |
466 |
| - ) -> Tensor: ... |
467 |
| - |
468 | 457 | def compute_clip_image_embedding(
|
469 | 458 | self,
|
470 | 459 | image_prompt: Tensor | Image.Image | list[Image.Image],
|
471 | 460 | weights: list[float] | None = None,
|
472 | 461 | concat_batches: bool = True,
|
473 | 462 | ) -> Tensor:
|
474 |
| - """Compute the CLIP image embedding. |
| 463 | + """Compute CLIP image embedding(s). |
475 | 464 |
|
476 | 465 | Args:
|
477 |
| - image_prompt: The image prompt to use. |
478 |
| - weights: The scale to use for the image prompt. |
479 |
| - concat_batches: Whether to concatenate the batches. |
| 466 | + image_prompt: The image prompt(s) to use. |
| 467 | + weights: The scale(s) to use for the image prompt(s). |
| 468 | + concat_batches: Whether to concatenate the image embeddings along the feature dimension. |
480 | 469 |
|
481 | 470 | Returns:
|
482 |
| - The CLIP image embedding. |
| 471 | + The CLIP image embedding(s). |
483 | 472 | """
|
484 | 473 | if isinstance(image_prompt, Image.Image):
|
485 | 474 | image_prompt = self.preprocess_image(image_prompt)
|
486 | 475 | elif isinstance(image_prompt, list):
|
487 |
| - assert all(isinstance(image, Image.Image) for image in image_prompt) |
| 476 | + assert all( |
| 477 | + isinstance(image, Image.Image) for image in image_prompt |
| 478 | + ), "All elements of `image_prompt` must be of type `Image.Image`" |
488 | 479 | image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
|
489 | 480 |
|
490 | 481 | negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
|
0 commit comments