
Conversation

albertvillanova
Member

@albertvillanova albertvillanova commented Oct 16, 2025

Replace the unittest skipTest from transformers (implicitly used by the requirement markers) with pytest.mark.skipif markers; a minimal before/after sketch of the pattern follows the list:

  • require_liger_kernel
  • require_flash_attn
  • require_torch_accelerator
  • require_torch_multi_accelerator
  • require_wandb
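
For context, a minimal before/after sketch of the pattern for one such marker (the is_wandb_available helper and test name below are illustrative, not the exact TRL code):

import importlib.util

import pytest


def is_wandb_available():
    # Availability probe; transformers ships similar is_*_available helpers
    return importlib.util.find_spec("wandb") is not None


# Before (unittest style): unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
# After (pytest style): a module-level marker, applied with a bare @require_wandb
require_wandb = pytest.mark.skipif(not is_wandb_available(), reason="test requires wandb")


@require_wandb
def test_wandb_logging():
    import wandb  # noqa: F401 -- only runs when wandb is installed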

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 88 to 99
# Function ported from transformers.testing_utils
def require_flash_attn():
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False

    return pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")
Collaborator

Suggested change

# Function ported from transformers.testing_utils
def require_flash_attn():
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False
    return pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")

# Function ported from transformers.testing_utils
def require_flash_attn(func=None):
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False
    marker = pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")
    if func is not None:
        return marker(func)
    return marker
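
With the func=None variant, both decoration styles work, so existing call sites don't need to change. A quick sketch, assuming the suggested require_flash_attn above is in scope (test names are illustrative):

@require_flash_attn  # bare form, as used by the existing tests
def test_generation_with_fa2():
    ...


@require_flash_attn()  # called form also works
def test_training_with_fa2():
    ...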

Collaborator

Otherwise we would have to change all our decorators to @require_flash_attn()

Member Author

Thanks a lot for your review, @kashif! 🤗

Finally, I aligned the implementation with the rest of the "require" pytest markers by calling an auxiliary is_x_available function.

Comment on lines +88 to +103
def is_flash_attn_available():
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False

    return kernels_available or flash_attn_available


# Function ported from transformers.testing_utils
require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention")
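
In this module-level form the availability check runs once, when the testing-utils module is imported, rather than per decorated test; usage stays a bare decorator. A minimal sketch (test name illustrative):

@require_flash_attn
def test_sft_with_flash_attention():
    ...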

Member

I don't think this one is 100% correct, because if kernels is available but flash-attn is not, and you try to use attn_implementation="flash_attention_2", it will most likely fail.

Member Author

@albertvillanova albertvillanova Oct 17, 2025

I just ported it from transformers. 😅
https://github.com/huggingface/transformers/blob/12a50f294d50e3d0e124511f2b6f43625f73ffce/src/transformers/testing_utils.py#L575-L591

def require_flash_attn(test_case):
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception as _:
        kernels_available = False

    return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)

Do you think I made an error in porting it? @qgallouedec

Member Author

@albertvillanova albertvillanova Oct 17, 2025

The logic on the transformers side was recently changed by this PR:

Modified require_flash_attn in testing_utils.py to allow tests to run if either FlashAttention2 or the community kernel is available, broadening test coverage and reliability.

Member

OK I see.
A lot of our tests rely on the flash-attn lib; it's probably a good time to drop flash-attn and rely only on kernels (a sketch of the first replacement follows the list):

  • replace all attn_implementation="flash_attention_2" -> attn_implementation="kernels-community/flash-attn"
  • replace require_flash_attn with require_kernels
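
For the first bullet, a minimal sketch of what the replacement would look like (model id is illustrative; assumes the kernels package is installed):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # illustrative model id
    attn_implementation="kernels-community/flash-attn",  # hub kernel instead of flash_attention_2
)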

@albertvillanova we can do this in a future PR

Member Author

I agree! I created an issue for that:

@albertvillanova albertvillanova merged commit bfd6f49 into huggingface:main Oct 22, 2025
8 of 10 checks passed
