Replace unittest skipTest from transformers with pytest.skip #4297
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
tests/testing_utils.py
Outdated
```python
# Function ported from transformers.testing_utils
def require_flash_attn():
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False

    return pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")
```
Suggested change:

```diff
 # Function ported from transformers.testing_utils
-def require_flash_attn():
+def require_flash_attn(func=None):
     flash_attn_available = is_flash_attn_2_available()
     kernels_available = is_kernels_available()
     try:
         from kernels import get_kernel
         get_kernel("kernels-community/flash-attn")
     except Exception:
         kernels_available = False
-    return pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")
+    marker = pytest.mark.skipif(not (kernels_available or flash_attn_available), reason="test requires Flash Attention")
+    if func is not None:
+        return marker(func)
+    return marker
```
Otherwise we would have to change our decorators to `@require_flash_attn()`.
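A minimal, stdlib-only sketch of the dual-use decorator pattern suggested above, so the same factory works both bare and called with parentheses. All names here are hypothetical stand-ins, not TRL's actual helpers, and the sketch records a skip reason instead of integrating with pytest:

```python
# Hypothetical sketch of the optional-func decorator pattern: the factory
# works both as @require_feature and as @require_feature().
FEATURE_AVAILABLE = False  # stand-in for is_flash_attn_2_available()


def require_feature(func=None):
    def marker(f):
        # Record why the test would be skipped instead of actually skipping it.
        f.skip_reason = None if FEATURE_AVAILABLE else "test requires the feature"
        return f

    if func is not None:
        return marker(func)  # used bare: @require_feature
    return marker            # used with parentheses: @require_feature()


@require_feature
def test_bare():
    pass


@require_feature()
def test_with_parens():
    pass


print(test_bare.skip_reason)         # test requires the feature
print(test_with_parens.skip_reason)  # test requires the feature
```

Both spellings end up applying the same marker, which is why the `func=None` branch avoids touching existing `@require_flash_attn` call sites.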
Thanks a lot for your review, @kashif! 🤗
Finally, I aligned the implementation with the rest of the "require" pytest markers by calling an auxiliary function `is_x_available`.
```python
def is_flash_attn_available():
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        kernels_available = False

    return kernels_available or flash_attn_available


# Function ported from transformers.testing_utils
require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention")
```
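As a hedged illustration of how a module-level `pytest.mark.skipif` marker like the one above is consumed (here `feature_available` is a hypothetical stand-in, and this assumes pytest is installed):

```python
# Sketch: a module-level skipif marker applied as a plain decorator.
import pytest


def feature_available():
    return False  # pretend the optional dependency is missing


require_feature = pytest.mark.skipif(
    not feature_available(), reason="test requires the feature"
)


@require_feature
def test_uses_feature():
    assert True


# pytest records the mark on the decorated function itself:
print(test_uses_feature.pytestmark[0].name)  # skipif
```

Note that the availability check runs once at import time, which is exactly why the marker is built from `is_flash_attn_available()` rather than re-probing inside each test.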
I don't think this one is 100% correct, because if kernels is available but not flash-attn, and you try to use `attn_implementation="flash_attention_2"`, it will most likely fail.
I just ported it from transformers. 😅
https://github.com/huggingface/transformers/blob/12a50f294d50e3d0e124511f2b6f43625f73ffce/src/transformers/testing_utils.py#L575-L591
```python
def require_flash_attn(test_case):
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception as _:
        kernels_available = False
    return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)
```
Do you think I made an error in porting it? @qgallouedec
The logic on the transformers side was recently changed by this PR:

> Modified `require_flash_attn` in `testing_utils.py` to allow tests to run if either FlashAttention2 or the community kernel is available, broadening test coverage and reliability.
OK I see.
A lot of our tests rely on the flash-attn lib; it's probably a good time to drop flash-attn and rely only on kernels:

- replace all `attn_implementation="flash_attention_2"` -> `attn_implementation="kernels-community/flash-attn"`
- replace `require_flash_attn` by `require_kernels`

@albertvillanova we can do this in a future PR
I agree! I created an issue for that:

> Replace unittest skipTest from transformers (implicitly used by requirement markers) with pytest.skip markers:
>
> - `require_liger_kernel`
> - `require_flash_attn`
> - `require_torch_accelerator`
> - `require_torch_multi_accelerator`
> - `require_wandb`
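For context on what the issue replaces, here is a stdlib-only sketch of how the unittest-based markers mark a test as skipped (the helper name and availability flag are illustrative, not transformers' actual code):

```python
# Sketch of a transformers-style unittest skip marker, using only the
# standard library.
import unittest


def require_feature(test_case):
    feature_available = False  # stand-in for is_flash_attn_2_available()
    return unittest.skipUnless(feature_available, "test requires the feature")(test_case)


@require_feature
def test_something():
    pass


# unittest records the skip via attributes on the decorated callable:
print(test_something.__unittest_skip__)      # True
print(test_something.__unittest_skip_why__)  # test requires the feature
```

Because these attributes are a unittest convention, pytest only honors them on `unittest.TestCase` methods; migrating to `pytest.mark.skipif` makes the skips native to plain pytest test functions.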