# hf related

from datasets import load_dataset
+ from diffusers.models import AutoencoderKL
+
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder = "vae")
+
+ class Encoder(Module):
+     def __init__(self, vae):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, image):
+         with torch.no_grad():
+             latent = self.vae.encode(image * 2 - 1)
+
+         return 0.18215 * latent.latent_dist.sample()
+
+ class Decoder(Module):
+     def __init__(self, vae):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, latents):
+         latents = (1 / 0.18215) * latents
+
+         with torch.no_grad():
+             image = self.vae.decode(latents).sample
+
+         return (image / 2 + 0.5).clamp(0, 1)
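The 0.18215 multiplier is the latent scaling factor Stable Diffusion v1 was trained with, and this VAE downsamples images 8x spatially into 4 latent channels. A quick shape check (a standalone sketch, not part of the commit) shows why the Transfusion config below moves to dim_latent = 4, channel_first_latent = True, and modality_default_shape = (32, 32):

import torch

enc, dec = Encoder(vae), Decoder(vae)

image = torch.rand(1, 3, 256, 256)       # pixel values in [0, 1]
latent = enc(image)
assert latent.shape == (1, 4, 32, 32)    # 4 channels, 256 / 8 = 32 per side

recon = dec(latent)
assert recon.shape == (1, 3, 256, 256)   # decoded back to clamped [0, 1] pixels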

# results folder

@@ -36,23 +63,13 @@ def divisible_by(num, den):

# encoder / decoder

- class Encoder(Module):
-     def forward(self, x):
-         x = rearrange(x, '... c (h p1) (w p2) -> ... h w (p1 p2 c)', p1 = 4, p2 = 4)
-         return x * 2 - 1
-
- class Decoder(Module):
-     def forward(self, x):
-         x = rearrange(x, '... h w (p1 p2 c) -> ... c (h p1) (w p2)', p1 = 4, p2 = 4, c = 3)
-         return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
-
model = Transfusion(
    num_text_tokens = 10,
-     dim_latent = 4 * 4 * 3,
-     channel_first_latent = False,
-     modality_default_shape = (16, 16),
-     modality_encoder = Encoder(),
-     modality_decoder = Decoder(),
+     dim_latent = 4,
+     channel_first_latent = True,
+     modality_default_shape = (32, 32),
+     modality_encoder = Encoder(vae),
+     modality_decoder = Decoder(vae),
    add_pos_emb = True,
    modality_num_dim = 2,
    velocity_consistency_loss_weight = 0.1,
@@ -65,28 +82,34 @@ def forward(self, x):
    )
).cuda()

- ema_model = model.create_ema()
+ ema_model = model.create_ema(0.9)
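Passing 0.9 to create_ema presumably sets the EMA decay. The shadow weights then track the online weights by the standard exponential moving average; a sketch of that rule under this assumption (ema_params and online_params are hypothetical names, the real update happens inside the EMA wrapper):

beta = 0.9

with torch.no_grad():
    for ema_p, online_p in zip(ema_params, online_params):
        # ema_p <- beta * ema_p + (1 - beta) * online_p
        ema_p.lerp_(online_p, 1 - beta)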

class FlowersDataset(Dataset):
-     def __init__(self):
+     def __init__(self, image_size):
        self.ds = load_dataset("nelorth/oxford-flowers")['train']

+         self.transform = T.Compose([
+             T.Resize((image_size, image_size)),
+             T.PILToTensor()
+         ])
+
    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        pil = self.ds[idx]['image']
-         image_tensor = T.PILToTensor()(pil)
-         return T.Resize((64, 64))(image_tensor / 255.)
+         tensor = self.transform(pil)
+         return tensor / 255.
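Each item now comes out as a channels-first float tensor in [0, 1] (a quick check, not part of the commit; assumes the flower images are RGB):

ds = FlowersDataset(256)
sample = ds[0]
assert sample.shape == (3, 256, 256)                             # PILToTensor yields C x H x W uint8
assert sample.min().item() >= 0. and sample.max().item() <= 1.   # / 255. rescales to [0, 1]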

def cycle(iter_dl):
    while True:
        for batch in iter_dl:
            yield batch

- dataset = FlowersDataset()
+ dataset = FlowersDataset(256)
+
+ dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

- dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

iter_dl = cycle(dataloader)

optimizer = Adam(model.parameters(), lr = 8e-4)
@@ -95,8 +118,9 @@ def cycle(iter_dl):

for step in range(1, 100_000 + 1):

-     loss = model.forward_modality(next(iter_dl))
-     loss.backward()
+     for _ in range(4):
+         loss = model.forward_modality(next(iter_dl))
+         (loss / 4).backward()
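    # editorial note, not part of the commit: 4 micro-batches of batch size 4
    # give an effective batch of 16, and dividing each loss by 4 before
    # backward() makes the accumulated gradients equal the gradient of the
    # mean loss over all 16 images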

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -108,9 +132,9 @@ def cycle(iter_dl):
    print(f'{step}: {loss.item():.3f}')

    if divisible_by(step, SAMPLE_EVERY):
-         image = ema_model.generate_modality_only(batch_size = 64)
+         image = ema_model.generate_modality_only(batch_size = 4)

        save_image(
-             rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 8).detach().cpu(),
+             rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 2).detach().cpu(),
            str(results_folder / f'{step}.png')
        )
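With batch_size = 4 and gh = 2, the rearrange above tiles the four samples into a 2 x 2 grid (gw is inferred as 2), so each saved PNG is a single 512 x 512 sheet of four 256 x 256 images. A standalone sketch of the tiling, not part of the commit:

import torch
from einops import rearrange

samples = torch.rand(4, 3, 256, 256)
grid = rearrange(samples, '(gh gw) c h w -> c (gh h) (gw w)', gh = 2)
assert grid.shape == (3, 512, 512)   # 2 x 2 grid of 256 x 256 samples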