@@ -294,13 +294,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
294
294
attention_batch_size = config .transformer_batch_size ,
295
295
normalization_config = config .transformer_norm_config ,
296
296
attention_config = attention_config ,
297
+ enable_hlfb = False ,
297
298
),
298
299
cross_attention_block_config = unet_cfg .CrossAttentionBlock2DConfig (
299
300
query_dim = output_channel ,
300
301
cross_dim = config .transformer_cross_attention_dim ,
301
302
attention_batch_size = config .transformer_batch_size ,
302
303
normalization_config = config .transformer_norm_config ,
303
304
attention_config = attention_config ,
305
+ enable_hlfb = False ,
304
306
),
305
307
pre_conv_normalization_config = config .transformer_pre_conv_norm_config ,
306
308
feed_forward_block_config = unet_cfg .FeedForwardBlock2DConfig (
@@ -354,13 +356,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
354
356
attention_batch_size = config .transformer_batch_size ,
355
357
normalization_config = config .transformer_norm_config ,
356
358
attention_config = attention_config ,
359
+ enable_hlfb = False ,
357
360
),
358
361
cross_attention_block_config = unet_cfg .CrossAttentionBlock2DConfig (
359
362
query_dim = mid_block_channels ,
360
363
cross_dim = config .transformer_cross_attention_dim ,
361
364
attention_batch_size = config .transformer_batch_size ,
362
365
normalization_config = config .transformer_norm_config ,
363
366
attention_config = attention_config ,
367
+ enable_hlfb = False ,
364
368
),
365
369
pre_conv_normalization_config = config .transformer_pre_conv_norm_config ,
366
370
feed_forward_block_config = unet_cfg .FeedForwardBlock2DConfig (
@@ -415,13 +419,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
415
419
attention_batch_size = config .transformer_batch_size ,
416
420
normalization_config = config .transformer_norm_config ,
417
421
attention_config = attention_config ,
422
+ enable_hlfb = False ,
418
423
),
419
424
cross_attention_block_config = unet_cfg .CrossAttentionBlock2DConfig (
420
425
query_dim = output_channel ,
421
426
cross_dim = config .transformer_cross_attention_dim ,
422
427
attention_batch_size = config .transformer_batch_size ,
423
428
normalization_config = config .transformer_norm_config ,
424
429
attention_config = attention_config ,
430
+ enable_hlfb = False ,
425
431
),
426
432
pre_conv_normalization_config = config .transformer_pre_conv_norm_config ,
427
433
feed_forward_block_config = unet_cfg .FeedForwardBlock2DConfig (
0 commit comments