@@ -1,9 +1,10 @@
 import math
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 
+import torch
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
+from torch import Tensor, device as Device, dtype as DType, nn
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
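The import hunk swaps the free-function imports (`cat`, `softmax`, `tensor`, `zeros_like`) for a single `import torch` plus module-qualified call sites. A minimal sanity sketch (assuming only a working PyTorch install) showing the two spellings refer to the same objects, so the refactor is behavior-preserving:

# The free-function imports are just aliases of the torch module attributes,
# so replacing `cat(...)` with `torch.cat(...)` cannot change behavior.
import torch
from torch import cat, softmax, tensor, zeros_like

assert cat is torch.cat
assert softmax is torch.softmax
assert tensor is torch.tensor
assert zeros_like is torch.zeros_like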
@@ -98,7 +99,7 @@ def forward(
         v = self.reshape_tensor(value)
 
         attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
-        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+        attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
         attention = attention @ v
 
         return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
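In this `forward` hunk, `self.scale` is applied to both `q` and `k` before the matmul, and the softmax runs in float32 before casting back, the usual half-precision stability tricks. A standalone sketch, assuming `scale = head_dim ** -0.25` (the per-side scale used in the upstream IP-Adapter resampler; the shapes below are illustrative), showing the double scaling is equivalent to the standard `1 / sqrt(head_dim)` logit scaling:

import torch

bs, heads, length, head_dim = 2, 4, 8, 64
scale = head_dim ** -0.25  # assumption: matches the resampler's per-side scale

q = torch.randn(bs, heads, length, head_dim)
k = torch.randn(bs, heads, length, head_dim)
v = torch.randn(bs, heads, length, head_dim)

# Scaling q and k separately keeps intermediates small (helps float16) but
# yields the same logits as dividing by sqrt(head_dim) afterwards.
logits = (q * scale) @ (k * scale).transpose(-2, -1)
reference = (q @ k.transpose(-2, -1)) / head_dim**0.5
assert torch.allclose(logits, reference, atol=1e-4)

# Softmax in float32, then cast back to the working dtype before applying to v.
attention = torch.softmax(logits.float(), dim=-1).type(logits.dtype) @ v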
@@ -159,7 +160,7 @@ def __init__(
         )
 
     def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
-        return cat((x, latents), dim=-2)
+        return torch.cat((x, latents), dim=-2)
 
 
 class LatentsToken(fl.Chain):
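`to_kv` builds the key/value input by concatenating the image features `x` with the learned latents along the token axis, so cross-attention in the resampler attends over both. A shape sketch (the sizes below are illustrative assumptions, not values from the diff):

import torch

batch, num_tokens, num_latents, dim = 2, 257, 16, 768
x = torch.randn(batch, num_tokens, dim)         # e.g. image features
latents = torch.randn(batch, num_latents, dim)  # learned latent queries

# dim=-2 is the token axis, so the key/value sequence covers features + latents.
kv = torch.cat((x, latents), dim=-2)
assert kv.shape == (batch, num_tokens + num_latents, dim)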
@@ -484,7 +485,7 @@ def compute_clip_image_embedding(
             image_prompt = self.preprocess_image(image_prompt)
         elif isinstance(image_prompt, list):
             assert all(isinstance(image, Image.Image) for image in image_prompt)
-            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+            image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
 
         negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
 
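When a list of PIL images is passed, each one is preprocessed to a single-image tensor and the results are concatenated into one batch. A sketch with stand-in tensors (the `(1, 3, 224, 224)` output shape of `preprocess_image` is an assumption for illustration):

import torch

# Stand-ins for self.preprocess_image(image); real shapes depend on the encoder's image size.
preprocessed = [torch.randn(1, 3, 224, 224) for _ in range(3)]

image_prompt = torch.cat(preprocessed)  # default dim=0 -> one batched tensor
assert image_prompt.shape == (3, 3, 224, 224)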
@@ -493,28 +494,28 @@ def compute_clip_image_embedding(
             assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
             if any(weight != 1.0 for weight in weights):
                 conditional_embedding *= (
-                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
                     .unsqueeze(-1)
                     .unsqueeze(-1)
                 )
 
         if batch_size > 1 and concat_batches:
             # Create a longer image tokens sequence when a batch of images is given
             # See https://github.com/tencent-ailab/IP-Adapter/issues/99
-            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
-            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+            negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
 
-        return cat((negative_embedding, conditional_embedding))
+        return torch.cat((negative_embedding, conditional_embedding))
 
     def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
         clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
         if not self.fine_grained:
-            negative_embedding = self.image_proj(zeros_like(clip_embedding))
+            negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
         else:
             # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
-            clip_embedding = image_encoder(zeros_like(image_prompt))
+            clip_embedding = image_encoder(torch.zeros_like(image_prompt))
             negative_embedding = self.image_proj(clip_embedding)
         return negative_embedding, conditional_embedding
 
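Two details in the last hunk are worth spelling out. The negative (unconditional) embedding is obtained by projecting zeroed inputs via `torch.zeros_like`, mirroring the referenced IP-Adapter training script. And with `concat_batches`, the per-image embeddings are chunked along the batch axis and re-concatenated along the token axis, producing one longer image-token sequence (see the linked issue). A sketch of that reshuffle, with assumed shapes:

import torch

batch_size, tokens, dim = 3, 4, 768
conditional_embedding = torch.randn(batch_size, tokens, dim)

# chunk(batch_size) splits along dim 0; cat(..., dim=1) stitches the pieces
# back along the token axis: (3, 4, 768) -> (1, 12, 768).
merged = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
assert merged.shape == (1, batch_size * tokens, dim)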