Skip to content

Commit c8ad852

Browse files
committed
Remove all uses of nvFuser legacy bindings
as a follow-up to #2812
1 parent 1e22aac commit c8ad852

File tree

5 files changed

+13
-32
lines changed

5 files changed

+13
-32
lines changed

dockers/ubuntu-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,5 +158,5 @@ RUN \
158158
pip list && \
159159
python -c "import sys; ver = sys.version_info ; assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \
160160
python -c "import torch; print(f'PyTorch=={torch.__version__} with {torch.cuda.device_count()} GPUs')" && \
161-
python -c "import nvfuser; print(f'nvFuser=={nvfuser.version()}')" && \
161+
python -c "import nvfuser_direct as nvfuser; print(f'nvFuser=={nvfuser.version()}')" && \
162162
python -c "import triton; print(f'Triton=={triton.__version__}')"

docs/source/basic/inspecting_traces.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ This will print the following::
206206
# cuda version: 12.1
207207
# nvfuser version: 0.2.8
208208
import torch
209-
from nvfuser import FusionDefinition, DataType
209+
from nvfuser_direct import FusionDefinition, DataType
210210

211211
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
212212
T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])

thunder/executors/nvfuserex.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,12 @@ def nvfuser_version() -> LooseVersion | None:
1818
try:
1919
import nvfuser_direct
2020
except ImportError:
21-
try:
22-
import nvfuser
23-
except ImportError:
24-
pass
25-
else:
26-
if hasattr(nvfuser, "version"):
27-
return LooseVersion(nvfuser.version())
28-
else:
29-
# NOTE: This import of nvFuser may or may not have version info
30-
return LooseVersion("0.0.0")
21+
return None
3122
else:
3223
if hasattr(nvfuser_direct, "version"):
3324
return LooseVersion(nvfuser_direct.version())
3425
else:
3526
return LooseVersion("0.0.0")
36-
# NOTE This occurs when nvFuser couldn't be imported
37-
return None
3827

3928

4029
def required_nvfuser_version() -> LooseVersion:

thunder/executors/nvfuserex_impl.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,24 +71,16 @@
7171
# NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
7272
# by nvfuserex.py when nvFuser is available.
7373

74-
DIRECT_BINDINGS_SUPPORTED_VERSION = LooseVersion("0.2.34")
7574
DTENSOR_SUPPORTED_VERSION = LooseVersion("0.2.28")
76-
if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
77-
import nvfuser_direct as nvfuser
78-
from nvfuser_direct import (
79-
DataType,
80-
FusionDefinition,
81-
multidevice,
82-
ParallelType,
83-
execute_with_dtensors,
84-
compute_tensor_descriptor as nv_compute_td,
85-
)
86-
else:
87-
if nvfuser_version() >= DTENSOR_SUPPORTED_VERSION:
88-
from nvfuser_direct import FusionDefinition as DirectFusionDefinition
89-
from nvfuser_direct import multidevice, ParallelType, execute_with_dtensors
90-
import nvfuser
91-
from nvfuser import DataType, FusionDefinition, compute_tensor_descriptor as nv_compute_td
75+
import nvfuser_direct as nvfuser
76+
from nvfuser_direct import (
77+
DataType,
78+
FusionDefinition,
79+
multidevice,
80+
ParallelType,
81+
execute_with_dtensors,
82+
compute_tensor_descriptor as nv_compute_td,
83+
)
9284

9385
#
9486
# Helper functions

thunder/tests/test_dynamo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,7 @@ def foo(x):
14031403
@pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2546")
14041404
@requiresCUDA
14051405
def test_WallTime_KernelTime():
1406-
from nvfuser import FusionDefinition, DataType
1406+
from nvfuser_direct import FusionDefinition, DataType
14071407

14081408
def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
14091409
T0 = fd.define_tensor(

0 commit comments

Comments
 (0)