Skip to content

Commit bc355f7

Browse files
committed
add _compile = True to create_block_mask
1 parent 55e3bbe commit bc355f7

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.9.4"
3+
version = "0.9.5"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ def forward(
984984
if needs_masking:
985985
if causal_mask:
986986
if should_use_flex_attn:
987-
block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
987+
block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device)
988988
attn_mask_kwargs.update(block_mask = block_mask)
989989
else:
990990
attn_mask_kwargs.update(causal = True)
@@ -994,7 +994,7 @@ def forward(
994994

995995
if should_use_flex_attn:
996996
transfusion_mask_fn = transfusion_attn_mask(modality_positions)
997-
block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
997+
block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device)
998998
attn_mask_kwargs.update(block_mask = block_mask)
999999
else:
10001000
attn_mask = naive_attn_mask(seq_len, modality_positions, device = device)

0 commit comments

Comments
 (0)