Skip to content

Commit cce06ec

Browse files
refactor: use __all__ in accelerators/__init__.py (#20889)
1 parent d195d2b commit cce06ec

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/lightning/pytorch/accelerators/__init__.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,26 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13+
14+
__all__ = [
15+
"Accelerator",
16+
"CPUAccelerator",
17+
"CUDAAccelerator",
18+
"MPSAccelerator",
19+
"XLAAccelerator",
20+
"find_usable_cuda_devices",
21+
]
22+
1323
import sys
1424

15-
from lightning.fabric.accelerators import find_usable_cuda_devices # noqa: F401
25+
from lightning.fabric.accelerators import find_usable_cuda_devices
1626
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
1727
from lightning.fabric.utilities.registry import _register_classes
1828
from lightning.pytorch.accelerators.accelerator import Accelerator
19-
from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401
20-
from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401
21-
from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401
22-
from lightning.pytorch.accelerators.xla import XLAAccelerator # noqa: F401
29+
from lightning.pytorch.accelerators.cpu import CPUAccelerator
30+
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
31+
from lightning.pytorch.accelerators.mps import MPSAccelerator
32+
from lightning.pytorch.accelerators.xla import XLAAccelerator
2333

2434
AcceleratorRegistry = _AcceleratorRegistry()
2535
_register_classes(AcceleratorRegistry, "register_accelerators", sys.modules[__name__], Accelerator)

0 commit comments

Comments
 (0)