
Commit 60863a3

Add fp4 support (#3532)

1 parent 6c7a8b6

12 files changed (+521 lines, -13 lines)

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions

@@ -21,14 +21,14 @@ repos:
       - id: clang-format
         types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.4.0
+    rev: 8.0.3
     hooks:
       - id: buildifier
        args:
          - --warnings=all
       - id: buildifier-lint
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.23
+    rev: v0.24.1
     hooks:
       - id: validate-pyproject
   - repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
       - id: isort
         name: isort (python)
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.9.0"
+    rev: "v1.15.0"
     hooks:
       - id: mypy
         exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.3.3
+    rev: v0.11.7
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 25.1.0
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
       - id: typos
   - repo: https://github.com/astral-sh/uv-pre-commit
     # uv version.
-    rev: 0.5.5
+    rev: 0.7.1
     hooks:
       # Update the uv lockfile
       - id: uv-lock

py/torch_tensorrt/_enums.py

Lines changed: 19 additions & 0 deletions

@@ -80,6 +80,12 @@ class dtype(Enum):
     :meta hide-value:
     """
 
+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
+
+    :meta hide-value:
+    """
+
     uint8 = u8
     int8 = i8
 
@@ -91,6 +97,9 @@ class dtype(Enum):
     float8 = f8
     fp8 = f8
 
+    float4 = f4
+    fp4 = f4
+
     half = f16
     fp16 = f16
     float16 = f16
@@ -162,6 +171,8 @@ def _from(
             return dtype.i32
         elif t == torch.float8_e4m3fn:
             return dtype.f8
+        elif t == torch.float4_e2m1fn_x2:
+            return dtype.f4
         elif t == torch.half:
             return dtype.f16
         elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
+        elif t == trt.DataType.FP4:
+            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
             return torch.long
         elif self == dtype.f8:
             return torch.float8_e4m3fn
+        elif self == dtype.f4:
+            return torch.float4_e2m1fn_x2
         elif self == dtype.f16:
             return torch.half
         elif self == dtype.f32:
@@ -394,6 +409,8 @@ def to(
             return trt.DataType.BOOL
         elif self == dtype.bf16:
             return trt.DataType.BF16
+        elif self == dtype.f4:
+            return trt.DataType.FP4
         elif use_default:
             return trt.DataType.FLOAT
         else:
@@ -410,6 +427,8 @@ def to(
             return np.int64
         elif self == dtype.f16:
             return np.float16
+        elif self == dtype.f4:
+            return np.float4_e2m1fn_x2
         elif self == dtype.f32:
             return np.float32
         elif self == dtype.f64:
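
Taken together, these hunks make f4 a first-class dtype with fp4/float4 aliases and wire it into every conversion path. A minimal sketch of the resulting round-trips, assuming a torch build that ships torch.float4_e2m1fn_x2 (2.7+) and a TensorRT wheel that exposes trt.DataType.FP4:

import torch
import tensorrt as trt
from torch_tensorrt import dtype

# The three new names all resolve to the same enum member.
assert dtype.f4 == dtype.fp4 == dtype.float4

# torch -> torch_tensorrt, via the classmethod the converter uses internally.
assert dtype._from(torch.float4_e2m1fn_x2) == dtype.f4

# torch_tensorrt -> TensorRT, and back out to torch.
assert dtype.f4.to(trt.DataType) == trt.DataType.FP4
assert dtype.f4.to(torch.dtype) == torch.float4_e2m1fn_x2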

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 2 deletions

@@ -257,7 +257,7 @@ def cross_compile_for_windows(
             x in enabled_precisions for x in {torch.float32, dtype.f32}
         ):
             raise AssertionError(
-                f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
@@ -588,7 +588,7 @@ def compile(
             x in enabled_precisions for x in {torch.float32, dtype.f32}
         ):
             raise AssertionError(
-                f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
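
The reworded assertion now spells out the contract: use_explicit_typing=True means precisions come from the network's own types, so enabled_precisions must be left at its default. A hedged sketch of both sides of that contract (the model and inputs are placeholders, not part of this commit):

import torch
import torch_tensorrt

model = torch.nn.Linear(16, 16).cuda().eval()
inputs = [torch.randn(1, 16).cuda()]

# OK: strong typing on, enabled_precisions left at its default ({dtype.f32}).
trt_model = torch_tensorrt.compile(model, inputs=inputs, use_explicit_typing=True)

# Raises the new AssertionError: explicit typing and a hand-picked
# precision set are mutually exclusive.
try:
    torch_tensorrt.compile(
        model,
        inputs=inputs,
        use_explicit_typing=True,
        enabled_precisions={torch.float16},
    )
except AssertionError as err:
    print(err)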

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 8 additions & 1 deletion

@@ -29,7 +29,14 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+SUPPORTED_KERNEL_PRECISIONS = {
+    dtype.f32,
+    dtype.f16,
+    dtype.bf16,
+    dtype.i8,
+    dtype.f8,
+    dtype.f4,
+}
 TIMING_CACHE_PATH = os.path.join(
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
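
This set is what the dynamo frontend validates requested kernel precisions against. A small illustrative sketch of such a check (check_precisions is a hypothetical helper, not part of this commit):

from torch_tensorrt import dtype
from torch_tensorrt.dynamo._defaults import SUPPORTED_KERNEL_PRECISIONS

def check_precisions(requested: set) -> None:
    # Hypothetical: normalize to torch_tensorrt dtypes, then diff against
    # the supported set.
    unsupported = {dtype._from(p) for p in requested} - SUPPORTED_KERNEL_PRECISIONS
    if unsupported:
        raise ValueError(f"Unsupported kernel precisions: {unsupported}")

check_precisions({dtype.f4, dtype.f16})  # passes once dtype.f4 is in the set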

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 8 additions & 0 deletions

@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
+from typing import Union
 
 import numpy as np
+import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork
 
@@ -21,3 +23,9 @@ class ConversionContext:
     )
     requires_output_allocator: bool = False
     mapping: dict[str, np.array] = field(default_factory=dict)
+    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
+        default_factory=dict
+    )
+
+    def clear_cpu_weights_reference_holder(self) -> None:
+        self.cpu_weights_reference_holder.clear()
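
The holder exists because trt.Weights built from a raw data_ptr() does not own its memory: the CPU tensor must stay alive until the engine is built and serialized. A self-contained sketch of that lifecycle, using a stand-in for the real ConversionContext:

from dataclasses import dataclass, field
from typing import Union

import numpy as np
import torch

@dataclass
class _Ctx:  # stand-in for ConversionContext
    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.ndarray]] = field(
        default_factory=dict
    )

    def clear_cpu_weights_reference_holder(self) -> None:
        self.cpu_weights_reference_holder.clear()

ctx = _Ctx()
packed = torch.randint(0, 255, (4, 8), dtype=torch.uint8)
ctx.cpu_weights_reference_holder["linear.weight FP4_CONSTANT"] = packed  # keep alive
# ... network construction and engine build would happen here ...
ctx.clear_cpu_weights_reference_holder()  # safe to drop after serialization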

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 0 deletions

@@ -743,6 +743,8 @@ def run(
         )
         _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
 
+        self.ctx.clear_cpu_weights_reference_holder()
+
         self._save_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 36 additions & 0 deletions

@@ -619,6 +619,42 @@ def aten_ops_quantize_op(
     )
 
 
+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception as e:
+    _LOGGER.warning(
+        "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        supports_dynamic_shapes=True,
+    )
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.dynamic_block_quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            args[4],
+            args[5],
+            args[6],
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
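
The try/except/else guard makes the converter registration conditional: it only runs when modelopt is importable and the custom op is registered. A generic, self-contained sketch of the same optional-dependency pattern (the registry and decorator here are illustrative stand-ins, not the real dynamo_tensorrt_converter machinery):

import logging
from typing import Any, Callable

_LOGGER = logging.getLogger(__name__)
_CONVERTERS: dict[str, Callable[..., Any]] = {}

def register(op_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Stand-in for @dynamo_tensorrt_converter.
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        _CONVERTERS[op_name] = fn
        return fn
    return wrap

try:
    import modelopt.torch.quantization as mtq  # noqa: F401  (optional dependency)
except Exception:
    _LOGGER.warning("modelopt not installed; quantize converter not registered")
else:
    @register("tensorrt.dynamic_block_quantize_op")
    def convert_dynamic_block_quantize(*args: Any) -> Any:
        ...  # lowering to TensorRT layers would go here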

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 44 additions & 3 deletions

@@ -326,6 +326,7 @@ def create_constant(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]],
     min_rank: Optional[int] = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `ctx.net`.
@@ -338,6 +339,7 @@ def create_constant(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If a dtype is given, we will convert the type of the given `value` to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If a quantized type is given, we will convert the type of the given `value` to this dtype.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -361,12 +363,48 @@ def create_constant(
         shape = list(torch_value.shape)
 
         if torch_value is not None:
+            if torch_value.dtype == torch.float8_e4m3fn:
+                weights = trt.Weights(
+                    type=trt.DataType.FP8,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel(),
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
+            if torch_value.dtype == torch.uint8:
+                if (
+                    target_quantized_type is None
+                    or target_quantized_type != trt.DataType.FP4
+                ):
+                    # IConstantLayer does not support uint8; it only supports FP4 data packed into uint8
+                    raise ValueError(
+                        f"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                    )
+                shape[-1] = shape[-1] * 2
+                weights = trt.Weights(
+                    type=trt.DataType.FP4,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel() * 2,
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
             if torch_value.dtype == torch.bfloat16:
                 torch_value_fp32 = torch_value.to(torch.float32)
                 numpy_value = torch_value_fp32.numpy()
             else:
                 numpy_value = torch_value.numpy()
-
             ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
             constant = ctx.net.add_constant(
                 shape,
@@ -381,7 +419,6 @@ def create_constant(
                     trt.DataType.BF16,
                     name + "_bf16_cast",
                 )
-
             return constant.get_output(0)
     else:
         raise ValueError(
@@ -395,6 +432,7 @@ def get_trt_tensor(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
     min_rank: int = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -408,6 +446,7 @@ def get_trt_tensor(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If dtype is provided, the given value will be converted to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If a quantized type is given, we will convert the type of the given `value` to this dtype.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -420,7 +459,9 @@ def get_trt_tensor(
         input_val = input_val.astype(np.float32)
 
     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
-        return create_constant(ctx, input_val, name, dtype, min_rank)
+        return create_constant(
+            ctx, input_val, name, dtype, min_rank, target_quantized_type
+        )
     elif isinstance(input_val, TRTTensor):
         return input_val
     else:
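
The doubling in the uint8 branch reflects FP4 packing: two 4-bit E2M1 values are stored per uint8 byte, so a packed tensor of shape (r, c) represents a logical FP4 tensor of shape (r, 2c) with numel() * 2 elements. A self-contained check of that arithmetic (shapes made up for illustration):

import torch

packed = torch.randint(0, 255, (128, 32), dtype=torch.uint8)  # packed FP4 weights

logical_shape = list(packed.shape)
logical_shape[-1] *= 2              # mirrors `shape[-1] = shape[-1] * 2`
logical_count = packed.numel() * 2  # mirrors `count=torch_value.numel() * 2`

assert logical_shape == [128, 64]
assert logical_count == 128 * 64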

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
     condition,
     conv,
     deconv,
+    dynamic_block_quantize,
     elementwise,
     embedding,
     full,
