Skip to content

Commit 4c40530

Browse files
authored
Temporarily disable HLFB for stable diffusion (#104)
* Tmp disable SDPA * Update
1 parent e79fdc0 commit 4c40530

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

ai_edge_torch/generative/examples/stable_diffusion/decoder.py

+1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
293293
qkv_fused_interleaved=False,
294294
rotary_percentage=0.0,
295295
),
296+
enable_hlfb=False,
296297
)
297298

298299
mid_block_config = unet_cfg.MidBlock2DConfig(

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

+6
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
294294
attention_batch_size=config.transformer_batch_size,
295295
normalization_config=config.transformer_norm_config,
296296
attention_config=attention_config,
297+
enable_hlfb=False,
297298
),
298299
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
299300
query_dim=output_channel,
300301
cross_dim=config.transformer_cross_attention_dim,
301302
attention_batch_size=config.transformer_batch_size,
302303
normalization_config=config.transformer_norm_config,
303304
attention_config=attention_config,
305+
enable_hlfb=False,
304306
),
305307
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
306308
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -354,13 +356,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
354356
attention_batch_size=config.transformer_batch_size,
355357
normalization_config=config.transformer_norm_config,
356358
attention_config=attention_config,
359+
enable_hlfb=False,
357360
),
358361
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
359362
query_dim=mid_block_channels,
360363
cross_dim=config.transformer_cross_attention_dim,
361364
attention_batch_size=config.transformer_batch_size,
362365
normalization_config=config.transformer_norm_config,
363366
attention_config=attention_config,
367+
enable_hlfb=False,
364368
),
365369
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
366370
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -415,13 +419,15 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
415419
attention_batch_size=config.transformer_batch_size,
416420
normalization_config=config.transformer_norm_config,
417421
attention_config=attention_config,
422+
enable_hlfb=False,
418423
),
419424
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
420425
query_dim=output_channel,
421426
cross_dim=config.transformer_cross_attention_dim,
422427
attention_batch_size=config.transformer_batch_size,
423428
normalization_config=config.transformer_norm_config,
424429
attention_config=attention_config,
430+
enable_hlfb=False,
425431
),
426432
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
427433
feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(

0 commit comments

Comments
 (0)