|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
6 | 6 | import math
|
| 7 | +import sys |
7 | 8 |
|
8 | 9 | import pytest
|
9 | 10 | import torch
|
@@ -550,8 +551,10 @@ def simulator1d(theta):
|
550 | 551 | pytest.param(
|
551 | 552 | "scan",
|
552 | 553 | marks=pytest.mark.xfail(
|
553 |
| - condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5), |
554 |
| - reason="PyTorch's associative_scan only exists for torch >= 2.5", |
| 554 | + condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5) |
| 555 | + or sys.version_info >= (3, 13), |
| 556 | + reason="PyTorch's associative_scan only exists for torch >= 2.5 \ |
| 557 | + and Python < 3.13", |
555 | 558 | strict=True,
|
556 | 559 | ),
|
557 | 560 | ),
|
@@ -603,8 +606,10 @@ def test_lru_isolated(
|
603 | 606 | pytest.param(
|
604 | 607 | "scan",
|
605 | 608 | marks=pytest.mark.xfail(
|
606 |
| - condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5), |
607 |
| - reason="PyTorch's associative_scan only exists for torch >= 2.5", |
| 609 | + condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5) |
| 610 | + or sys.version_info >= (3, 13), |
| 611 | + reason="PyTorch's associative_scan only exists for torch >= 2.5 \ |
| 612 | + and Python < 3.13", |
608 | 613 | strict=True,
|
609 | 614 | ),
|
610 | 615 | ),
|
@@ -663,8 +668,10 @@ def test_lru_block_isolated(
|
663 | 668 | pytest.param(
|
664 | 669 | "scan",
|
665 | 670 | marks=pytest.mark.xfail(
|
666 |
| - condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5), |
667 |
| - reason="PyTorch's associative_scan only exists for torch >= 2.5", |
| 671 | + condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5) |
| 672 | + or sys.version_info >= (3, 13), |
| 673 | + reason="PyTorch's associative_scan only exists for torch >= 2.5 \ |
| 674 | + and Python < 3.13", |
668 | 675 | strict=True,
|
669 | 676 | ),
|
670 | 677 | ),
|
@@ -787,8 +794,10 @@ def _simulator(thetas: Tensor, num_time_steps=500, dt=0.002, eps=0.05) -> Tensor
|
787 | 794 |
|
788 | 795 |
|
789 | 796 | @pytest.mark.xfail(
|
790 |
| - condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5), |
791 |
| - reason="PyTorch's associative_scan only exists for torch >= 2.5", |
| 797 | + condition=tuple(map(int, torch.__version__.split('.')[:2])) < (2, 5) |
| 798 | + or sys.version_info >= (3, 13), |
| 799 | + reason="PyTorch's associative_scan only exists for torch >= 2.5 \ |
| 800 | + and Python < 3.13", |
792 | 801 | strict=True,
|
793 | 802 | )
|
794 | 803 | @pytest.mark.filterwarnings("ignore:Torchinductor")
|
|
0 commit comments