2
2
3
3
import ctypes
4
4
import gc
5
+ import getpass
5
6
import logging
6
7
import os
8
+ import tempfile
7
9
import urllib .request
8
10
import warnings
11
+ from contextlib import contextmanager
9
12
from dataclasses import fields , replace
10
13
from enum import Enum
11
- from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
14
+ from pathlib import Path
15
+ from typing import (
16
+ Any ,
17
+ Callable ,
18
+ Dict ,
19
+ Iterator ,
20
+ List ,
21
+ Optional ,
22
+ Sequence ,
23
+ Tuple ,
24
+ Union ,
25
+ )
12
26
13
27
import numpy as np
14
28
import sympy
37
51
RTOL = 5e-3
38
52
ATOL = 5e-3
39
53
CPU_DEVICE = "cpu"
54
+ _WHL_CPYTHON_VERSION = "3.10"
40
55
41
56
42
57
class Frameworks (Enum ):
@@ -823,17 +838,39 @@ def is_tegra_platform() -> bool:
823
838
return False
824
839
825
840
826
@contextmanager
def download_plugin_lib_path(platform: str) -> Iterator[str]:
    """
    Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the
    specified platform, then yields the path to the extracted shared library
    (.so or .dll).

    The wheel file is cached in a user-specific temporary directory to avoid
    repeated downloads.  Extraction happens in a temporary directory that is
    cleaned up when the context exits, so the yielded path is only valid
    inside the ``with`` block.

    Args:
        platform (str): The platform identifier string (e.g., 'linux_x86_64')
            to select the correct wheel.

    Yields:
        str: The full path to the extracted TensorRT-LLM shared library file.

    Raises:
        RuntimeError: If the wheel is absent (e.g. the download failed) or is
            not a valid zip archive.
    """
    # zipfile is stdlib and always importable; the previous try/except
    # ImportError guard around it was dead code and has been removed
    # (the docstring's "Raises: ImportError" was wrong for the same reason).
    import zipfile

    # Cache the wheel per-user under the system temp dir so repeated calls
    # skip the (large) download.
    username = getpass.getuser()
    torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
    torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
    torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name

    if not torchtrt_cache_trtllm_whl.exists():
        # Downloading TRT-LLM lib.  Download failures are deliberately only
        # logged here; they surface below as a RuntimeError when the missing
        # wheel cannot be opened.
        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
        download_url = base_url + file_name
        try:
            logger.debug(f"Downloading {download_url}...")
            urllib.request.urlretrieve(download_url, torchtrt_cache_trtllm_whl)
            logger.debug("Download succeeded and TRT-LLM wheel is now present")
        except urllib.error.HTTPError as e:
            # NOTE(review): handler bodies reconstructed from a partial diff;
            # confirm message wording against the original file.
            logger.error(
                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
            )
        except urllib.error.URLError as e:
            logger.error(
                f"URL error when trying to download {download_url}: {e.reason}"
            )
        except OSError as e:
            logger.error(f"Local file write error: {e}")

    # Select the platform-specific shared-library name inside the wheel.
    if "linux" in platform:
        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
    else:
        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"

    # Proceeding with the unzip of the wheel file in tmpdir.
    with tempfile.TemporaryDirectory() as tmpdir:
        try:
            # BUGFIX: always open the cached wheel path.  The previous code
            # referenced `downloaded_file_path`, a name bound only inside the
            # download branch above, raising NameError on every cache hit.
            with zipfile.ZipFile(torchtrt_cache_trtllm_whl, "r") as zip_ref:
                zip_ref.extractall(tmpdir)  # Extracts a 'tensorrt_llm' folder
        except FileNotFoundError as e:
            # This should capture the errors in the download failure above
            logger.error(f"Wheel file not found at {torchtrt_cache_trtllm_whl}: {e}")
            raise RuntimeError(
                f"Failed to find downloaded wheel file at {torchtrt_cache_trtllm_whl}"
            ) from e
        except zipfile.BadZipFile as e:
            logger.error(f"Invalid or corrupted wheel file: {e}")
            raise RuntimeError(
                "Downloaded wheel file is corrupted or not a valid zip archive"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error while extracting wheel: {e}")
            raise RuntimeError(
                "Unexpected error during extraction of TensorRT-LLM wheel"
            ) from e
        plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
        yield plugin_lib_path
920
+
921
+
922
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
    """
    Loads and initializes the TensorRT-LLM plugin from the given shared library path.

    Args:
        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library
            (.so on Linux, .dll on Windows).

    Returns:
        bool: True if the plugin was loaded and initialized, False otherwise.
            All failures are logged rather than raised.
    """
    try:
        # Load the shared TRT-LLM file.
        handle = ctypes.CDLL(plugin_lib_path)
        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        # NOTE(review): this branch's interior was reconstructed from a
        # partial diff -- confirm the condition and wording against the
        # original file.
        if "libmpi" in str(e_os_error):
            logger.warning(
                f"Failed to load {plugin_lib_path} because an MPI dependency "
                f"(libmpi) could not be found. Install an MPI runtime and retry.",
                exc_info=e_os_error,
            )
        else:
            # BUGFIX: the message no longer hardcodes the Linux .so filename,
            # which was misleading when the .dll failed to load on Windows.
            logger.warning(
                f"Failed to load TensorRT-LLM plugin library from {plugin_lib_path}. "
                "Ensure the path is correct and the library is compatible.",
                exc_info=e_os_error,
            )
        return False

    try:
        # Configure plugin initialization arguments.
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        # The library loaded but does not export the expected entry point.
        logger.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    try:
        # b"tensorrt_llm" is the plugin namespace passed to initTrtLlmPlugins.
        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
            logger.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
            logger.warning("TensorRT-LLM plugin library failed in initialization")
            return False
    except Exception as e_initialization_error:
        logger.warning(
            f"Exception occurred during TensorRT-LLM plugin library initialization: {e_initialization_error}"
        )
        return False
    return False
975
+
976
+
977
def load_tensorrt_llm() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.

    The library path can be supplied via the TRTLLM_PLUGINS_PATH environment
    variable; otherwise setting USE_TRTLLM_PLUGINS to one of (1, true, yes, on)
    directs this function to download the TRT-LLM distribution and load the
    plugin from it.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        # BUGFIX: message previously named "TRTLLM_PLUGIN_PATH"; the variable
        # actually consulted above is TRTLLM_PLUGINS_PATH.
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    platform = str(Platform.current_platform()).lower()
    # The extracted library only exists inside this context, so the load must
    # happen before the temporary directory is cleaned up.
    with download_plugin_lib_path(platform) as plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)
0 commit comments