Skip to content

Commit 9d7bc82

Browse files
authored
Move _KINETO_AVAILABLE check to profiler (#18575)
1 parent cbf2e9d commit 9d7bc82

File tree

3 files changed

+5
-10
lines changed

3 files changed

+5
-10
lines changed

src/lightning/pytorch/profilers/pytorch.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,23 @@
2222
import torch
2323
from torch import nn, Tensor
2424
from torch.autograd.profiler import EventList, record_function
25+
from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
26+
from torch.utils.hooks import RemovableHandle
2527

2628
from lightning.fabric.accelerators.cuda import is_cuda_available
2729
from lightning.pytorch.profilers.profiler import Profiler
2830
from lightning.pytorch.utilities.exceptions import MisconfigurationException
29-
from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
3031
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
3132

3233
if TYPE_CHECKING:
33-
from torch.utils.hooks import RemovableHandle
34-
3534
from lightning.pytorch.core.module import LightningModule
3635

37-
if _KINETO_AVAILABLE:
38-
from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
3936

4037
log = logging.getLogger(__name__)
4138
warning_cache = WarningCache()
4239

4340
_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
41+
_KINETO_AVAILABLE = torch.profiler.kineto_available()
4442

4543

4644
class RegisterRecordFunction:
@@ -65,7 +63,7 @@ class RegisterRecordFunction:
6563
def __init__(self, model: nn.Module) -> None:
6664
self._model = model
6765
self._records: Dict[str, record_function] = {}
68-
self._handles: Dict[str, List["RemovableHandle"]] = {}
66+
self._handles: Dict[str, List[RemovableHandle]] = {}
6967

7068
def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
7169
# Add [pl][module] in name for pytorch profiler to recognize

src/lightning/pytorch/utilities/imports.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@
1515
import functools
1616
import sys
1717

18-
import torch
1918
from lightning_utilities.core.imports import package_available, RequirementCache
2019
from lightning_utilities.core.rank_zero import rank_zero_warn
2120

2221
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
2322
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
2423
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
2524

26-
_KINETO_AVAILABLE = torch.profiler.kineto_available()
2725
_OMEGACONF_AVAILABLE = package_available("omegaconf")
2826
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
2927
_LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")

tests/tests_pytorch/profilers/test_profiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
2828
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
2929
from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
30-
from lightning.pytorch.profilers.pytorch import RegisterRecordFunction, warning_cache
30+
from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache
3131
from lightning.pytorch.utilities.exceptions import MisconfigurationException
32-
from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
3332
from tests_pytorch.helpers.runif import RunIf
3433

3534
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005

0 commit comments

Comments
 (0)