Skip to content

Commit 37d4e90

Browse files
committed
Addressing review comments: tmp dir for wheel download and wheel extraction, variable for py_version
1 parent 23d27b0 commit 37d4e90

File tree

1 file changed

+120
-59
lines changed

1 file changed

+120
-59
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 120 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,27 @@
22

33
import ctypes
44
import gc
5+
import getpass
56
import logging
67
import os
8+
import tempfile
79
import urllib.request
810
import warnings
11+
from contextlib import contextmanager
912
from dataclasses import fields, replace
1013
from enum import Enum
11-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
14+
from pathlib import Path
15+
from typing import (
16+
Any,
17+
Callable,
18+
Dict,
19+
Iterator,
20+
List,
21+
Optional,
22+
Sequence,
23+
Tuple,
24+
Union,
25+
)
1226

1327
import numpy as np
1428
import sympy
@@ -37,6 +51,7 @@
3751
RTOL = 5e-3
3852
ATOL = 5e-3
3953
CPU_DEVICE = "cpu"
54+
_WHL_CPYTHON_VERSION = "3.10"
4055

4156

4257
class Frameworks(Enum):
@@ -823,17 +838,39 @@ def is_tegra_platform() -> bool:
823838
return False
824839

825840

826-
def download_plugin_lib_path(py_version: str, platform: str) -> str:
827-
plugin_lib_path = None
841+
@contextmanager
842+
def download_plugin_lib_path(platform: str) -> Iterator[str]:
843+
"""
844+
Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
845+
then yields the path to the extracted shared library (.so or .dll).
846+
847+
The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
848+
Extraction happens in a temporary directory that is cleaned up after use.
849+
850+
Args:
851+
platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
828852
829-
# Downloading TRT-LLM lib
830-
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
831-
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
832-
download_url = base_url + file_name
833-
if not (os.path.exists(file_name)):
853+
Yields:
854+
str: The full path to the extracted TensorRT-LLM shared library file.
855+
856+
Raises:
857+
ImportError: If the 'zipfile' module is not available.
858+
"""
859+
plugin_lib_path = None
860+
username = getpass.getuser()
861+
torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
862+
torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
863+
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
864+
torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
865+
866+
if not torchtrt_cache_trtllm_whl.exists():
867+
# Downloading TRT-LLM lib
868+
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
869+
download_url = base_url + file_name
870+
downloaded_file_path = torchtrt_cache_trtllm_whl
834871
try:
835872
logger.debug(f"Downloading {download_url} ...")
836-
urllib.request.urlretrieve(download_url, file_name)
873+
urllib.request.urlretrieve(download_url, downloaded_file_path)
837874
logger.debug("Download succeeded and TRT-LLM wheel is now present")
838875
except urllib.error.HTTPError as e:
839876
logger.error(
@@ -846,60 +883,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str:
846883
except OSError as e:
847884
logger.error(f"Local file write error: {e}")
848885

849-
# Proceeding with the unzip of the wheel file
850-
# This will exist if the filename was already downloaded
886+
# Proceeding with the unzip of the wheel file in tmpdir
851887
if "linux" in platform:
852888
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
853889
else:
854890
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
855-
plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
856-
if os.path.exists(plugin_lib_path):
857-
return plugin_lib_path
858-
try:
859-
import zipfile
860-
except ImportError as e:
861-
raise ImportError(
862-
"zipfile module is required but not found. Please install zipfile"
863-
)
864-
with zipfile.ZipFile(file_name, "r") as zip_ref:
865-
zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm'
866-
plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
867-
return plugin_lib_path
868891

892+
with tempfile.TemporaryDirectory() as tmpdir:
893+
try:
894+
import zipfile
895+
except ImportError:
896+
raise ImportError(
897+
"zipfile module is required but not found. Please install zipfile"
898+
)
899+
try:
900+
with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
901+
zip_ref.extractall(tmpdir) # Extract to a folder named 'tensorrt_llm'
902+
except FileNotFoundError as e:
903+
# This should capture the errors in the download failure above
904+
logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
905+
raise RuntimeError(
906+
f"Failed to find downloaded wheel file at {downloaded_file_path}"
907+
) from e
908+
except zipfile.BadZipFile as e:
909+
logger.error(f"Invalid or corrupted wheel file: {e}")
910+
raise RuntimeError(
911+
"Downloaded wheel file is corrupted or not a valid zip archive"
912+
) from e
913+
except Exception as e:
914+
logger.error(f"Unexpected error while extracting wheel: {e}")
915+
raise RuntimeError(
916+
"Unexpected error during extraction of TensorRT-LLM wheel"
917+
) from e
918+
plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
919+
yield plugin_lib_path
920+
921+
922+
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
923+
"""
924+
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
869925
870-
def load_tensorrt_llm() -> bool:
871-
"""
872-
Attempts to load the TensorRT-LLM plugin and initialize it.
873-
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
874-
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
926+
Args:
927+
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
875928
876929
Returns:
877-
bool: True if the plugin was successfully loaded and initialized, False otherwise.
930+
bool: True if successful, False otherwise.
878931
"""
879-
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
880-
if not plugin_lib_path:
881-
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
882-
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
883-
"1",
884-
"true",
885-
"yes",
886-
"on",
887-
)
888-
if not use_trtllm_plugin:
889-
logger.warning(
890-
"Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
891-
)
892-
return False
893-
else:
894-
# this is used as the default py version
895-
py_version = "cp310"
896-
platform = Platform.current_platform()
897-
898-
platform = str(platform).lower()
899-
plugin_lib_path = download_plugin_lib_path(py_version, platform)
900-
901932
try:
902-
# Load the shared TRT-LLM file
903933
handle = ctypes.CDLL(plugin_lib_path)
904934
logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
905935
except OSError as e_os_error:
@@ -912,14 +942,13 @@ def load_tensorrt_llm() -> bool:
912942
)
913943
else:
914944
logger.warning(
915-
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
916-
f"Ensure the path is correct and the library is compatible",
945+
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
946+
f"Ensure the path is correct and the library is compatible.",
917947
exc_info=e_os_error,
918948
)
919949
return False
920950

921951
try:
922-
# Configure plugin initialization arguments
923952
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
924953
handle.initTrtLlmPlugins.restype = ctypes.c_bool
925954
except AttributeError as e_plugin_unavailable:
@@ -930,9 +959,7 @@ def load_tensorrt_llm() -> bool:
930959
return False
931960

932961
try:
933-
# Initialize the plugin
934-
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
935-
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
962+
if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
936963
logger.info("TensorRT-LLM plugin successfully initialized")
937964
return True
938965
else:
@@ -945,3 +972,37 @@ def load_tensorrt_llm() -> bool:
945972
)
946973
return False
947974
return False
975+
976+
977+
def load_tensorrt_llm() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.

    Resolution order:
      1. If the env variable TRTLLM_PLUGINS_PATH is set, load the plugin
         directly from that path.
      2. Otherwise, if USE_TRTLLM_PLUGINS is one of (1, true, yes, on),
         download the TRT-LLM distribution for the current platform and
         load the plugin from the extracted wheel.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # TRTLLM_PLUGINS_PATH not set; fall back to the opt-in download path.
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        # NOTE: message previously said TRTLLM_PLUGIN_PATH; the variable
        # actually consulted above is TRTLLM_PLUGINS_PATH.
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    platform = str(Platform.current_platform()).lower()
    # The context manager extracts the wheel into a temporary directory and
    # cleans it up after the library has been loaded and initialized, so the
    # call must happen inside the `with` body.
    with download_plugin_lib_path(platform) as plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

0 commit comments

Comments
 (0)