Skip to content

Commit dd882dc

Browse files
famuraMatthijspals
andauthored
tests: torch version xfail condition for LRU for 'scan' (#1552)
* Changed the xfail condition * mark slow test * ruff * add mark parameterize bidirectional * mark slow test * mark more slow tests * ruff * flip sign condition * flip signs condition --------- Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com> Co-authored-by: Matthijspals <matthijs-pals@hotmail.com>
1 parent d6fb040 commit dd882dc

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tests/embedding_net_test.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from __future__ import annotations
55

66
import math
7-
import sys
87

98
import pytest
109
import torch
@@ -355,6 +354,7 @@ def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50):
355354
@pytest.mark.parametrize("num_channels", (1, 2, 3))
356355
@pytest.mark.parametrize("change_c_mode", ["conv", "zeros"])
357356
@pytest.mark.parametrize("n_stages", [1, 3, 4])
357+
@pytest.mark.slow
358358
def test_2d_ResNet_cnn_embedding_net(
359359
input_shape, num_channels, change_c_mode, n_stages
360360
):
@@ -439,24 +439,26 @@ def simulator1d(theta):
439439
posterior.potential(s)
440440

441441

442-
@pytest.mark.parametrize(
443-
"bidirectional", [True, False], ids=["one-directional", "bi-directional"]
444-
)
445442
@pytest.mark.parametrize(
446443
"mode",
447444
[
448445
"loop",
449446
pytest.param(
450447
"scan",
451448
marks=pytest.mark.xfail(
452-
condition=sys.version_info >= (3, 13),
453-
reason="torch.compiler is not yet supported on Python >= 3.13",
449+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
450+
reason="PyTorch's associative_scan only exists for torch >= 2.5",
454451
strict=True,
455452
),
456453
),
457454
],
458455
ids=["loop", "scan"],
459456
)
457+
@pytest.mark.parametrize(
458+
"bidirectional", [True, False], ids=["one-directional", "bi-directional"]
459+
)
460+
@pytest.mark.slow
461+
@pytest.mark.filterwarnings("ignore:Torchinductor")
460462
def test_lru_isolated(
461463
bidirectional: bool,
462464
mode: str,
@@ -497,8 +499,8 @@ def test_lru_isolated(
497499
pytest.param(
498500
"scan",
499501
marks=pytest.mark.xfail(
500-
condition=sys.version_info >= (3, 13),
501-
reason="torch.compiler is not yet supported on Python >= 3.13",
502+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
503+
reason="PyTorch's associative_scan only exists for torch >= 2.5",
502504
strict=True,
503505
),
504506
),
@@ -510,6 +512,8 @@ def test_lru_isolated(
510512
[True, False],
511513
ids=["input-normalization", "no-input-normalization"],
512514
)
515+
@pytest.mark.slow
516+
@pytest.mark.filterwarnings("ignore:Torchinductor")
513517
def test_lru_block_isolated(
514518
bidirectional: bool,
515519
mode: str,
@@ -555,8 +559,8 @@ def test_lru_block_isolated(
555559
pytest.param(
556560
"scan",
557561
marks=pytest.mark.xfail(
558-
condition=sys.version_info >= (3, 13),
559-
reason="torch.compiler is not yet supported on Python >= 3.13",
562+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
563+
reason="PyTorch's associative_scan only exists for torch >= 2.5",
560564
strict=True,
561565
),
562566
),
@@ -566,6 +570,8 @@ def test_lru_block_isolated(
566570
@pytest.mark.parametrize(
567571
"aggregate_fcn", ["last_step", "mean"], ids=["last-step", "mean"]
568572
)
573+
@pytest.mark.slow
574+
@pytest.mark.filterwarnings("ignore:Torchinductor")
569575
def test_lru_embedding_net_isolated(
570576
bidirectional: bool,
571577
mode: str,
@@ -677,10 +683,11 @@ def _simulator(thetas: Tensor, num_time_steps=500, dt=0.002, eps=0.05) -> Tensor
677683

678684

679685
@pytest.mark.xfail(
680-
condition=sys.version_info >= (3, 13),
681-
reason="torch.compiler is not yet supported on Python >= 3.13",
686+
condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5),
687+
reason="PyTorch's associative_scan only exists for torch >= 2.5",
682688
strict=True,
683689
)
690+
@pytest.mark.filterwarnings("ignore:Torchinductor")
684691
def test_scan(
685692
input_dim: int = 3,
686693
output_dim: int = 3,

0 commit comments

Comments
 (0)