Skip to content

Commit 1757616

Browse files
authored
fix: lru test xfail condition (#1568)
* add xfail condition python 3.13 * linebreaks formatting * fix reason xfail * fix too long lines
1 parent 1210fce commit 1757616

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

tests/embedding_net_test.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import math
7+
import sys
78

89
import pytest
910
import torch
@@ -550,8 +551,10 @@ def simulator1d(theta):
550551
pytest.param(
551552
"scan",
552553
marks=pytest.mark.xfail(
553-
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
554-
reason="PyTorch's associative_scan only exists for torch >= 2.5",
554+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5)
555+
or sys.version_info >= (3, 13),
556+
reason="PyTorch's associative_scan only exists for torch >= 2.5 \
557+
and Python < 3.13",
555558
strict=True,
556559
),
557560
),
@@ -603,8 +606,10 @@ def test_lru_isolated(
603606
pytest.param(
604607
"scan",
605608
marks=pytest.mark.xfail(
606-
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
607-
reason="PyTorch's associative_scan only exists for torch >= 2.5",
609+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5)
610+
or sys.version_info >= (3, 13),
611+
reason="PyTorch's associative_scan only exists for torch >= 2.5 \
612+
and Python < 3.13",
608613
strict=True,
609614
),
610615
),
@@ -663,8 +668,10 @@ def test_lru_block_isolated(
663668
pytest.param(
664669
"scan",
665670
marks=pytest.mark.xfail(
666-
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
667-
reason="PyTorch's associative_scan only exists for torch >= 2.5",
671+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5)
672+
or sys.version_info >= (3, 13),
673+
reason="PyTorch's associative_scan only exists for torch >= 2.5 \
674+
and Python < 3.13",
668675
strict=True,
669676
),
670677
),
@@ -787,8 +794,10 @@ def _simulator(thetas: Tensor, num_time_steps=500, dt=0.002, eps=0.05) -> Tensor
787794

788795

789796
@pytest.mark.xfail(
790-
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
791-
reason="PyTorch's associative_scan only exists for torch >= 2.5",
797+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5)
798+
or sys.version_info >= (3, 13),
799+
reason="PyTorch's associative_scan only exists for torch >= 2.5 \
800+
and Python < 3.13",
792801
strict=True,
793802
)
794803
@pytest.mark.filterwarnings("ignore:Torchinductor")

0 commit comments

Comments
 (0)