From e5723dbcfeecd0ed39ba5c227619813cc2db6cf6 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 25 Apr 2025 21:03:59 +0000 Subject: [PATCH 01/30] chore: kv cache prototyping --- .../dynamo/lowering/_decomposition_groups.py | 1 + .../dynamo/lowering/_decompositions.py | 24 +++++++++---------- .../lowering/passes/_aten_lowering_pass.py | 2 ++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 825be75076..6df05f6940 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -171,6 +171,7 @@ aten.upsample_bilinear2d.vec, aten.upsample_trilinear3d.vec, aten.upsample_bicubic2d.vec, + aten.linear, } diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 8037858151..2601c35e4a 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -442,9 +442,9 @@ def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tenso return aten._reshape_copy.default(x, size) -@register_torch_trt_decomposition( - aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS -) +# @register_torch_trt_decomposition( +# aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) def scaled_dot_product_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -488,9 +488,9 @@ def scaled_dot_product_attention_decomposition( return attn_weight @ value -@register_torch_trt_decomposition( - aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS -) +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) def scaled_dot_product_flash_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -517,9 +517,9 @@ def scaled_dot_product_flash_attention_decomposition( return attn, None, None, None, 0, 0, None, None, None -@register_torch_trt_decomposition( - aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS -) +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) def scaled_dot_product_efficient_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -537,9 +537,9 @@ def scaled_dot_product_efficient_attention_decomposition( return attn, None, None, None -@register_torch_trt_decomposition( - aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS -) +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) def scaled_dot_product_cudnn_attention_decomposition( query: torch.Tensor, key: torch.Tensor, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..e47ecb6191 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo.utils import is_tegra_platform from .accumulate_fp32_matmul import accumulate_fp32_matmul +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from 
.fuse_prims_broadcast import fuse_prims_broadcast @@ -25,6 +26,7 @@ replace_max_pool_with_indices, remove_assert_nodes, accumulate_fp32_matmul, + lower_scaled_dot_product_attention, remove_num_users_is_0_nodes, ] From 3e0b46a5cfeb92b12d72bdc2dc739c0f57544832 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sun, 27 Apr 2025 00:47:03 +0000 Subject: [PATCH 02/30] chore: add sdpa converter/lowering --- examples/dynamo/utils.py | 51 +++++- .../dynamo/conversion/aten_ops_converters.py | 30 ++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/attention.py | 165 +++++++++++++++++ .../lower_scaled_dot_product_attention.py | 169 ++++++++++++++++++ 5 files changed, 415 insertions(+), 1 deletion(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/attention.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 25ad99c12d..496eadb5c9 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -4,7 +4,9 @@ EosTokenCriteria, MaxLengthCriteria, ) - +import numpy as np +import copy +import timeit def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): """ @@ -61,3 +63,50 @@ def generate(model, input_seq, max_tokens, eos_token_id): break return input_seq + + +def time_generate( + generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + inputs_copy = copy.copy(inputs) + _ = generate_fn( + model, inputs_copy, output_seq_length, eos_token_id + ) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings + + +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): + """ + Records different timing stats and adds it to the result + """ + times = np.array(timings) + speeds = batch_size / times + time_mean = np.mean(times).item() + time_med = np.median(times).item() + time_99th = np.percentile(times, 99).item() + time_std = np.std(times, ddof=0).item() + speed_mean = np.mean(speeds).item() + speed_med = np.median(speeds).item() + + stats = { + "Backend": backend, + "Precision": precision, + "Batch size": batch_size, + "Median(FPS)": speed_med, + "Mean(FPS)": speed_mean, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, + } + return stats \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..05b4582191 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1896,6 +1896,36 @@ def aten_ops_minimum( args[1], ) +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + # Currently, `attn_mask` is not supported + return args_bounds_check(node.args, 3) is None + +@dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + capability_validator=attention_validator, + supports_dynamic_shapes=True, +) +def tensorrt_scaled_dot_product_attention( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, 
Sequence[TRTTensor]]: + return impl.attention.scaled_dot_product_attention( + ctx, + target, + SourceIR.TORCHTRT_LOWERED, + name, + args[0], + args[1], + args[2], + args_bounds_check(args, 5, False), + kwargs.get("scale", None), + ) + @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..a8b0fbe284 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -1,5 +1,6 @@ from torch_tensorrt.dynamo.conversion.impl import ( activation, + attention, addmm, arange, cast, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py new file mode 100644 index 0000000000..71dfb5f818 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -0,0 +1,165 @@ +import math +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # the lower triangle of the tensor means the rows greater than and equal to the cols + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 + ) + # get the rows + row_tensor = impl.elementwise.trunc_div( + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col + ) + # get the cols + col_tensor = impl.elementwise.fmod( + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col + ) + cond = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_tensor, col_tensor + ) + return impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", cond, [row, col] + ) + + +def scaled_dot_product_attention( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + query: TRTTensor, + key: TRTTensor, + value: TRTTensor, + is_causal: bool, + scale: Optional[float], +) -> TRTTensor: + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( 
+ ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, -2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) + + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) + + # this is to generate a tensor which has shape (L, S), type is int32 + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 + ) + shape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] + ) + + # since we want our attn_bias to be in float32, so cast it to float32 + shape_tensor = cast_trt_tensor( + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir + ) + + # initialize the attn_bias as the zeros tensor + attn_bias = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 + ) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + inf_tensor = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") + ) + cond = impl.elementwise.eq( + ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) + ) + # mask out the certain part of the attn_bias + attn_bias = impl.condition.select( + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + ) + + scaled = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py new file mode 100644 index 0000000000..ee7651cb8c --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -0,0 +1,169 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, 
replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out 
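+    # NOTE: these pattern graphs intentionally omit the attn_mask/dropout arguments;
+    # torch.fx.subgraph_rewriter matches on graph structure, and kwargs such as
+    # `scale`, `is_causal` and `attn_bias` are repaired on the replacement nodes
+    # afterwards in lower_scaled_dot_product_attention above.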
+ + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement \ No newline at end of file From e30fa42b92915148d981e8879e68b72d867ef6be Mon Sep 17 00:00:00 2001 From: Chengzhe Xu Date: Wed, 7 May 2025 22:42:56 +0000 Subject: [PATCH 03/30] feat: implement static/dynamic kv cache in Torch-TRT --- examples/dynamo/llama3_trt.py | 250 +++++++++++++++++ examples/dynamo/llama_benchmark.py | 77 ++++++ examples/dynamo/utils.py | 42 ++- py/torch_tensorrt/dynamo/_compiler.py | 26 ++ .../runtime/_PythonTorchTensorRTModule.py | 6 +- py/torch_tensorrt/dynamo/utils.py | 4 +- py/torch_tensorrt/extensions/__init__.py | 1 + py/torch_tensorrt/extensions/hf/__init__.py | 2 + .../extensions/hf/dynamic_cache.py | 249 +++++++++++++++++ .../extensions/hf/static_cache.py | 258 ++++++++++++++++++ py/torch_tensorrt/extensions/hf/utils.py | 152 +++++++++++ 11 files changed, 1055 insertions(+), 12 deletions(-) create mode 100644 examples/dynamo/llama3_trt.py create mode 100644 examples/dynamo/llama_benchmark.py create mode 100644 py/torch_tensorrt/extensions/__init__.py create mode 100644 py/torch_tensorrt/extensions/hf/__init__.py create mode 100644 py/torch_tensorrt/extensions/hf/dynamic_cache.py create mode 100644 py/torch_tensorrt/extensions/hf/static_cache.py create mode 100644 py/torch_tensorrt/extensions/hf/utils.py diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py new file mode 100644 index 0000000000..79281ffcf1 --- /dev/null +++ b/examples/dynamo/llama3_trt.py @@ -0,0 +1,250 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using the dynamo backend +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. 
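+The --model flag selects among the Llama 3.2/3.1 and Gemma 3 checkpoints handled below,
+and --kv_cache enables the KV-cache lowering passes registered by torch_tensorrt.extensions.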
+""" + +import argparse +import copy +import os +import timeit + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList +from contextlib import nullcontext +from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache + +MAX_TOKENS = 128 +DEVICE = torch.device("cuda:0") + +def get_model(args): + with torch.no_grad(): + if args.model == "meta-llama/Llama-3.2-1B-Instruct": + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + ) + .eval() + .half() + .cuda() + ) + + elif args.model == "meta-llama/Llama-3.2-3B-Instruct": + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + # num_hidden_layers=2 + ) + .eval() + .half() + .cuda() + ) + elif args.model == "meta-llama/Llama-3.1-8B-Instruct": + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", # num_hidden_layers=1 + ) + .eval() + .half() + .cuda() + ) + elif args.model == "google/gemma-3-1b-it": + model = ( + AutoModelForCausalLM.from_pretrained( + "google/gemma-3-1b-it", use_cache=False, attn_implementation="sdpa" + ) + .eval() + .half() + .cuda() + ) + model = model.to(torch.float16) + return model + + +def compile_torchtrt(model, input_ids, min_block_size=1, debug=False): + max_seq_len = input_ids.shape[1] + MAX_TOKENS + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + + with (torch_tensorrt.logging.debug() if debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids], + enabled_precisions={torch.float16}, + # truncate_double=True, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=debug, + min_block_size=min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + + + print(f"============================= {backend_name} ==============================") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("=============================") + +def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. 
+ + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + kv_cache_inputs = placeholder_nodes[1:-2] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=DEVICE)) + + return tuple(zeroed_kv_cache_inputs) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model" + ) + arg_parser.add_argument( + "--tokenizer_path", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", help="Prompt") + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--disable_pytorch_run", + action="store_false", + help="Disable pytorch run (default: True)" + ) + arg_parser.add_argument( + "--kv_cache", + action="store_true", + help="Enable kv_cache (default: False)" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + prompt = "What is parallel programming ?" 
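+        # Tokenize the prompt; the resulting input_ids are moved to the CUDA device below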
+ model_inputs = tokenizer(prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + + # Prepare input prompt + # word = "What" + # word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence + # input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + MAX_TOKENS + # Pyt + pytorch_input_signature = (input_ids.clone(),) + if args.disable_pytorch_run: + pyt_gen_tokens = None + else: + pyt_gen_tokens = generate( + model, pytorch_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + + pyt_timings = time_generate( + generate, + model, + pytorch_input_signature, + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None + ) + + # TRT + if args.kv_cache: + # This import is required to register static/dynamic KV cache transformations as lowering passes + import torch_tensorrt.extensions + + trt_model = compile_torchtrt(model, input_ids, min_block_size=args.min_block_size, debug=args.debug) + + if args.kv_cache: + trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + trt_gen_tokens = generate_with_kv_cache( + trt_model, trt_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + trt_timings = time_generate( + generate_with_kv_cache, + trt_model, + trt_input_signature, + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + else: + trt_gen_tokens = generate( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + trt_stats = recordStats( + "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None + ) + + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + print("===================== \n") + if not args.disable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/examples/dynamo/llama_benchmark.py b/examples/dynamo/llama_benchmark.py new file mode 100644 index 0000000000..d08c477456 --- /dev/null +++ b/examples/dynamo/llama_benchmark.py @@ -0,0 +1,77 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import timeit + +USE_CACHE = True +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +MAX_NEW_TOKENS = 128 + + +def main(): + # Initialize model and tokenizer + print("Loading model and tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.float16, + use_cache=False, + device_map="auto" + ) + model.generation_config.cache_implementation = "static" + model.forward = torch.compile(model.forward) + + # Prepare input prompt + word = "What" + # Tokenize the word + word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence + # Repeat the token 2048 times + input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + print(f"Input tensor shape: {input_ids.shape}") + + # # Warm-up pass + print("Running warm-up pass...") + output_ids = model.generate( + input_ids, + 
max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + + # Benchmark loop + print("Running benchmark...") + num_iterations = 10 + total_time = 0 + timings = [] + + for i in range(num_iterations): + start_time = timeit.default_timer() + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + end_time = timeit.default_timer() + generation_time = end_time - start_time + total_time += generation_time + timings.append(generation_time) + + # Decode and print first iteration output + # if i == 0: + # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + # print("\nFirst generation output:") + # print(output_text) + + # Calculate and print statistics + average_time = total_time / num_iterations + print(f"\nPerformance Statistics:") + print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") + print(f"Average tokens per second: {100/average_time:.2f}") + print("\nIndividual timings (ms):") + for i, t in enumerate(timings): + print(f"Iteration {i+1}: {t*1000:.2f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 496eadb5c9..61fc9ada7a 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -39,31 +39,54 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): return ep -def generate(model, input_seq, max_tokens, eos_token_id): +def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): """ Greedy decoding of the model. This generates up to max_tokens. """ - # Max length of output seq = current input_seq length + max_tokens allowed to generate - max_output_seq_length = input_seq.shape[1] + max_tokens stopping_criteria = StoppingCriteriaList( [ MaxLengthCriteria(max_length=max_output_seq_length), EosTokenCriteria(eos_token_id=eos_token_id), ] ) - - while True: + isl = input_seq.shape[1] + osl = max_output_seq_length - isl + num_tokens_generated = 0 + while num_tokens_generated < osl: outputs = model(input_seq) logits = outputs.logits next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) - # TODO: Handle batch in this check - if stopping_criteria(input_seq, logits).item(): + num_tokens_generated += 1 + # # TODO: Handle batch in this check + if not benchmark and stopping_criteria(input_seq, logits).item(): break - + return input_seq +def generate_with_kv_cache(model, input_signature, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with KV cache. 
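+    input_signature is expected to be (input_ids, *kv_cache_tensors); on each step the
+    current start_idx/end_idx are appended before calling the model, and the KV tensors
+    returned by the model are fed back on the next step together with the new token.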
+ """ + start_idx = 0 + end_idx = input_signature[0].shape[1] + output_seq = input_signature[0].clone() + + # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL + while end_idx < max_output_seq_length: + input_signature_with_start_end_idx = input_signature + (start_idx, end_idx) + logits_keys_values = model(*input_signature_with_start_end_idx) + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + input_signature = (next_tokens, *kv_cache) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + start_idx = end_idx + end_idx = start_idx + 1 + + return output_seq def time_generate( generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 @@ -74,9 +97,8 @@ def time_generate( timings = [] for _ in range(iterations): start_time = timeit.default_timer() - inputs_copy = copy.copy(inputs) _ = generate_fn( - model, inputs_copy, output_seq_length, eos_token_id + model, inputs, output_seq_length, eos_token_id ) torch.cuda.synchronize() end_time = timeit.default_timer() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index acd16a32f0..39d5ce7464 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -684,6 +684,7 @@ def compile( ) gm = exported_program.module() + exported_program.module().to("cpu") logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module @@ -769,6 +770,30 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) + + # Store the original input spec for later use + original_in_spec = getattr(gm, '_in_spec', None) + original_out_spec = getattr(gm, '_out_spec', None) + + # Function to preserve and restore module specs + def preserve_module_specs(in_spec, out_spec, target_module): + """ + Applies input and output specs to the target module. 
+ + Args: + in_spec: The input spec to apply + out_spec: The output spec to apply + target_module: The module to apply specs to + """ + # Apply specs to target module + if in_spec is not None: + target_module._in_spec = in_spec + if out_spec is not None: + target_module._out_spec = out_spec + + return target_module + + # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -816,6 +841,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: continue submodule_node_dict[node.name] = node + preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6415ce11c3..fe4b781505 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -743,7 +743,11 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) + tensor_inputs = [ + t if isinstance(t, torch.Tensor) else torch.tensor(t) + for t in inputs + ] + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) # If the new shape key differs from the existing one, # invalidate the old shape key and remove the CUDAGraph diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index e4018ae95c..d59323a5d8 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) - """ Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64 """ - if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)): + if isinstance(tensor, (torch.Tensor, FakeTensor)): + return tensor.dtype + elif isinstance(tensor, (int, float, bool)): return torch.tensor(tensor).dtype elif isinstance(tensor, torch.SymInt): return torch.int64 diff --git a/py/torch_tensorrt/extensions/__init__.py b/py/torch_tensorrt/extensions/__init__.py new file mode 100644 index 0000000000..80112f1526 --- /dev/null +++ b/py/torch_tensorrt/extensions/__init__.py @@ -0,0 +1 @@ +from . 
import hf \ No newline at end of file diff --git a/py/torch_tensorrt/extensions/hf/__init__.py b/py/torch_tensorrt/extensions/hf/__init__.py new file mode 100644 index 0000000000..19d2a493ea --- /dev/null +++ b/py/torch_tensorrt/extensions/hf/__init__.py @@ -0,0 +1,2 @@ +from .static_cache import * +# from .dynamic_cache import * \ No newline at end of file diff --git a/py/torch_tensorrt/extensions/hf/dynamic_cache.py b/py/torch_tensorrt/extensions/hf/dynamic_cache.py new file mode 100644 index 0000000000..b7348157ec --- /dev/null +++ b/py/torch_tensorrt/extensions/hf/dynamic_cache.py @@ -0,0 +1,249 @@ +import logging +from typing import Dict, List, Tuple, Union, Sequence, Any + +import torch +from torch.fx.node import Target + +import torch_tensorrt +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +from .utils import add_graph_input, create_random_output_tensors, get_kv_nodes +import tensorrt +import torch.utils._pytree as pytree +logger = logging.getLogger(__name__) + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. + + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output + +def add_kv_as_outputs(gm): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. 
+ """ + # list of MHA kernels we would want to detect and replace + mha_ops = { + torch._C._nn.scaled_dot_product_attention, + } + + # Find all SDPA nodes in the graph + mha_nodes = [] + for node in gm.graph.nodes: + if is_op(node, mha_ops): + mha_nodes.append(node) + + # Iterate through each MHA node to extract shape information + for mha_node in mha_nodes: + if "val" in mha_node.meta and len(mha_node.args) >= 3: + # Get the input nodes (query, key, value) + q_node, k_node, v_node = mha_node.args[:3] + + # Add the copy nodes as outputs to the graph + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + ((k_node, v_node),) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs, (k_node, v_node)) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + + + +def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors and index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx") + end_idx_input = add_graph_input(gm, "end_idx") + return kv_inputs, start_idx_input, end_idx_input + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + pass + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + + for idx, sdpa_node in enumerate(sdpa_nodes): + + +def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Insert a torch.cond operation before each scaled_dot_product_attention operation. 
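+    The inserted torch.cond selects, per SDPA call, between the freshly computed
+    keys/values and those same keys/values concatenated (along the sequence
+    dimension) with the incoming cached keys/values, based on an `is_generate`
+    graph input.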
+ + Args: + gm: The FX GraphModule to modify + + Returns: + The modified GraphModule + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + + # For each SDPA node, insert a torch.cond operation before it + for idx, sdpa_node in enumerate(sdpa_nodes): + + with gm.graph.inserting_before(sdpa_node): + pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + # Create nodes for concatenating k with incoming_key and v with incoming_value + concatenated_k_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([k_node, incoming_key], 2), # Concatenate along sequence length dimension + kwargs={} + ) + concatenated_v_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([v_node, incoming_value], 2), # Concatenate along sequence length dimension + kwargs={} + ) + + # Create the torch.cond node + cond_k_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(pred_node, concatenated_k_node, k_node), + ) + + cond_v_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(pred_node, concatenated_v_node, v_node), + ) + + sdpa_node.args = (q_node, cond_k_node, cond_v_node) + + return gm + + + +@_aten_lowering_pass +def insert_dynamic_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert FlashInfer MHA + KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + + # Call the function to add QKV as outputs + logits_keys_values = add_kv_as_outputs(gm, start_idx_input, end_idx_input) + + gm = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + # gm = insert_torch_cond_before_sdpa(gm, kv_inputs) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm + + diff --git a/py/torch_tensorrt/extensions/hf/static_cache.py b/py/torch_tensorrt/extensions/hf/static_cache.py new file mode 100644 index 0000000000..55e10d0377 --- /dev/null +++ b/py/torch_tensorrt/extensions/hf/static_cache.py @@ -0,0 +1,258 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from .utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. 
+ + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors and index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get input_ids metadata from the graph. 
The first placeholder node in the graph is the input_ids node + # find the first placeholder node, then pull out meta["val"] + placeholder_node = next(node for node in gm.graph.nodes if node.op == "placeholder") + input_ids_meta = placeholder_node.meta["val"] + seq_len = input_ids_meta.shape[1] + + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + shape_env = seq_len.node.shape_env + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + pass + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): + # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(cat,), + kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={} + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={} + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={} + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + + # Call the function to add QKV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm + + diff --git a/py/torch_tensorrt/extensions/hf/utils.py b/py/torch_tensorrt/extensions/hf/utils.py new file mode 100644 index 0000000000..7a17ad7e65 --- /dev/null +++ b/py/torch_tensorrt/extensions/hf/utils.py @@ -0,0 +1,152 @@ +import torch +from torch.fx import Graph, GraphModule, Node +from typing import Optional, Union, Iterable, List, Tuple +from torch._ops import OpOverloadPacket +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._pytree import _LEAF_SPEC +from torch._export.utils import _detect_fake_mode_from_gm + +def get_kv_nodes(gm): + """ + Get 
the key and value nodes from the graph. + """ + kv_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + q_node, k_node, v_node = node.args[:3] + kv_nodes.append((k_node, v_node)) + return kv_nodes + +def get_random_tensor_from_node(node: Node) -> torch.Tensor: + """ + Creates a random tensor based on the shape information in a node's metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + node: A torch.fx.Node object with metadata containing tensor information + + Returns: + A random tensor with shape matching the node's metadata, or None if no valid + tensor information is found + """ + if "val" not in node.meta: + raise ValueError(f"No tensor information found in node metadata for node: {node}") + + fake_tensor = node.meta["val"] + shape = [] + + # Iterate through each dimension and handle symbolic dimensions + for dim in fake_tensor.shape: + if isinstance(dim, torch.SymInt): + # Extract the maximum value from the shape environment + max_val = dim.node.hint + shape.append(max_val) + else: + shape.append(dim) + + # Create a random tensor with the determined shape + dtype = fake_tensor.dtype + device = fake_tensor.device + random_tensor = torch.rand(shape, dtype=dtype, device=device) + + return random_tensor + +def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: + """ + Creates random tensors based on the shape information in node metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + nodes: List of torch.fx.Node objects with metadata + + Returns: + List of random tensors with shapes matching the nodes' metadata + """ + random_tensors = [] + + for node in nodes: + if isinstance(node, Node): + node_tensor = get_random_tensor_from_node(node) + elif isinstance(node, tuple): + node_tensor_list = [] + for n in node: + random_tensor = get_random_tensor_from_node(n) + node_tensor_list.append(random_tensor) + node_tensor = tuple(node_tensor_list) + + random_tensors.append(node_tensor) + + return random_tensors + +def add_graph_input( + gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None +) -> Node: + """Add a graph input to the given GraphModule and return the newly created node. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the input to. + name (str): The name of the input. + val (torch.Tensor): An example tensor to use for the input. + dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] + """ + # check that no dynamic shape is provided... 
+ if dynamic_shape: + raise NotImplementedError("Dynamic shape not supported for adding graph inputs") + + # extract graph and input spec + graph: Graph = gm.graph + + in_spec = graph._codegen.pytree_info.in_spec + in_spec_for_args = in_spec.children_specs[0] + orig_args = graph._codegen.pytree_info.orig_args + assert in_spec_for_args.type is tuple + + # insert input node after currently last input node + node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] + with graph.inserting_after(node_last_input): + in_node = graph.placeholder(name) + in_spec_for_args.children_specs.append(_LEAF_SPEC) + orig_args.append(f"arg_{name}") + + # update pytree info recursively with __post_init__ starting at leaves + def call_post_init(spec): + for child_spec in spec.children_specs: + call_post_init(child_spec) + spec.__post_init__() + + call_post_init(in_spec) + + # set fake tensor information if all required information is available + fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) + if fake_mode and val is not None and isinstance(val, torch.Tensor): + if isinstance(val, FakeTensor): + fake_tensor = val + else: + fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) + in_node.meta["val"] = fake_tensor + in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) + + # return new node... + return in_node + +def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: + """Check if the node is a call to one of the ops.""" + if node.op != "call_function": + return False + # check if it's a single op that's provided + if isinstance(ops, OpOverloadPacket): + ops = [ops] + + # check if it's the op itself instead of an overload + if any(node.target == op for op in ops): + return True + + return False + +def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: + input_nodes: List[Node] = graph.find_nodes(op="placeholder") + output_nodes: List[Node] = graph.find_nodes(op="output") + return (input_nodes, output_nodes) \ No newline at end of file From a3a202f72a9f4e80d90f90fa3b5b65ed3c7de83b Mon Sep 17 00:00:00 2001 From: Chengzhe Xu Date: Tue, 13 May 2025 23:15:40 +0000 Subject: [PATCH 04/30] chore: updates --- examples/dynamo/llama3_trt.py | 124 ++- examples/dynamo/test_if.py | 86 ++ examples/dynamo/utils.py | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 4 + .../dynamo/runtime/_PythonCUDAGraphModule.py | 771 ++++++++++++++++++ 5 files changed, 957 insertions(+), 35 deletions(-) create mode 100644 examples/dynamo/test_if.py create mode 100644 py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py index 79281ffcf1..857113ad03 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llama3_trt.py @@ -21,7 +21,7 @@ from contextlib import nullcontext from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache -MAX_TOKENS = 128 + DEVICE = torch.device("cuda:0") def get_model(args): @@ -34,7 +34,6 @@ def get_model(args): attn_implementation="sdpa", ) .eval() - .half() .cuda() ) @@ -47,7 +46,6 @@ def get_model(args): # num_hidden_layers=2 ) .eval() - .half() .cuda() ) elif args.model == "meta-llama/Llama-3.1-8B-Instruct": @@ -58,37 +56,58 @@ def get_model(args): attn_implementation="sdpa", # num_hidden_layers=1 ) .eval() - .half() .cuda() ) elif args.model == "google/gemma-3-1b-it": model = ( AutoModelForCausalLM.from_pretrained( - "google/gemma-3-1b-it", use_cache=False, 
attn_implementation="sdpa" + "google/gemma-3-1b-it", + use_cache=False, + attn_implementation="sdpa" ) .eval() - .half() .cuda() ) - model = model.to(torch.float16) + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + return model -def compile_torchtrt(model, input_ids, min_block_size=1, debug=False): - max_seq_len = input_ids.shape[1] + MAX_TOKENS +def compile_torchtrt(model, input_ids, args): + max_seq_len = input_ids.shape[1] + args.max_tokens ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - - with (torch_tensorrt.logging.debug() if debug else nullcontext()): + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile( ep, inputs=[input_ids], - enabled_precisions={torch.float16}, + enabled_precisions=enabled_precisions, # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, device=DEVICE, disable_tf32=True, use_python_runtime=True, - debug=debug, - min_block_size=min_block_size, + debug=args.debug, + min_block_size=args.min_block_size, ) return trt_model @@ -96,8 +115,6 @@ def compile_torchtrt(model, input_ids, min_block_size=1, debug=False): def print_outputs(backend_name, gen_tokens, tokenizer): - - print(f"============================= {backend_name} ==============================") print( f"{backend_name} model generated text: ", tokenizer.decode(gen_tokens[0], skip_special_tokens=True), @@ -127,6 +144,33 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): return tuple(zeroed_kv_cache_inputs) +def measure_perf(trt_model, input_signature, backend_name): + # Measure average time for 10 iterations + import timeit + import numpy as np + + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + # print(f"Iteration {i+1}: {iter_time:.4f} seconds") + + avg_time = total_time / iterations + print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds") + print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second") + if __name__ == "__main__": arg_parser = argparse.ArgumentParser( description="Run inference on a model with random input values" @@ -151,15 +195,23 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): "--min_block_size", type=int, default=1, help="no. of iterations to run" ) arg_parser.add_argument( - "--disable_pytorch_run", - action="store_false", - help="Disable pytorch run (default: True)" + "--max_tokens", type=int, default=128, help="no. 
of iterations to run" + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)" ) arg_parser.add_argument( "--kv_cache", action="store_true", help="Enable kv_cache (default: False)" ) + arg_parser.add_argument( + "--cudagraph", + action="store_true", + help="Enable cudagraphs (default: False)" + ) arg_parser.add_argument( "--debug", action="store_true", @@ -180,20 +232,20 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): # word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence # input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + MAX_TOKENS + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.max_tokens # Pyt - pytorch_input_signature = (input_ids.clone(),) - if args.disable_pytorch_run: - pyt_gen_tokens = None - else: + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + if args.enable_pytorch_run: pyt_gen_tokens = generate( - model, pytorch_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id ) pyt_timings = time_generate( generate, model, - pytorch_input_signature, + input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, iterations=args.iterations, @@ -207,13 +259,21 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): # This import is required to register static/dynamic KV cache transformations as lowering passes import torch_tensorrt.extensions - trt_model = compile_torchtrt(model, input_ids, min_block_size=args.min_block_size, debug=args.debug) - + trt_model = compile_torchtrt(model, input_ids, args) + if args.kv_cache: + trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
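            # Shape-keyed CUDAGraphs (illustrative note): the prefill call runs with input_ids of
            # shape (1, ISL) while every decode call runs with shape (1, 1), so a single warm-up
            # decoding pass captures one graph per input-shape key and the timed passes afterwards
            # replay those graphs instead of re-recording them. The compiled module's placeholder
            # order is assumed to be (input_ids, *kv_cache, start_idx, end_idx), so a standalone
            # prefill call would look roughly like:
            #   logits, *new_kv = trt_model(input_ids, *get_zeroed_kv_cache_inputs(trt_model),
            #                               0, input_ids.shape[1])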
+ trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) trt_gen_tokens = generate_with_kv_cache( trt_model, trt_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - ) + ) + + trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) trt_timings = time_generate( generate_with_kv_cache, trt_model, @@ -239,10 +299,10 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None ) - + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) print_outputs("TensorRT", trt_gen_tokens, tokenizer) - print("===================== \n") - if not args.disable_pytorch_run: + if args.enable_pytorch_run: print("=========PyTorch PERFORMANCE============ \n") print(pyt_stats) print("===================== \n") diff --git a/examples/dynamo/test_if.py b/examples/dynamo/test_if.py new file mode 100644 index 0000000000..ca573f330c --- /dev/null +++ b/examples/dynamo/test_if.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from torch.export import export + +class ConditionalModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, flag): + def true_fn(q, k, v, k1, v1): + k_new = torch.cat((k, k1), dim=2) + v_new = torch.cat((v, v1), dim=2) + return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) + + def false_fn(q, k, v, k1, v1): + return torch._C._nn.scaled_dot_product_attention(q, k, v) + + out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) + + return 2 * out + +class ConditionalModel2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, start_idx, end_idx): + + new_k1 = torch.cat((k1[:, :, :start_idx, :], k, k1[:, :, end_idx:, :]), dim=2) + new_v1 = torch.cat((v1[:, :, :start_idx, :], v, v1[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, new_k1[:,:,:end_idx,:], new_v1[:,:,:end_idx,:]) + + return out, new_k1, new_v1 + + +def main(): + # Create model + model = ConditionalModel2() + model.eval() # Set to evaluation mode + + # Create example inputs + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + k1 = torch.zeros(1, 32, 2176, 64).cuda() + v1 = torch.zeros(1, 32, 2176, 64).cuda() + # example_flag = torch.tensor(True) + start_idx = 0 + end_idx = 2048 + out_pyt = model(q, k, v, k1, v1, start_idx, end_idx) + out_pyt2 = model(q, k, v, k1, v1, 17, 18) + + exported_program = export( + model, + args=(q, k, v, k1, v1, start_idx, end_idx), + dynamic_shapes=(None, None, None, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC), + strict=False + ) + import torch_tensorrt + with torch_tensorrt.logging.debug(): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v, k1, v1, start_idx, end_idx], + enabled_precisions={torch.float32}, + # truncate_double=True, + disable_tf32=True, + use_python_runtime=True, + debug=True, + min_block_size=1, + ) + + gm = exported_program.module() + breakpoint() + out_ep = gm(q, k, v, k1, v1, start_idx, end_idx) + out_ep2 = gm(q, k, v, k1, v1, 2048, 2049) + out_trt = trt_model(q, k, v, k1, v1, start_idx, end_idx) + out_trt2 = trt_model(q, k, v, k1, v1, 2048, 2049) + # breakpoint() + # Print the graph + print("\nExported Graph:") + 
print(exported_program.graph_module.graph) + # breakpoint() + # print("done") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 61fc9ada7a..7379e09ba9 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -72,9 +72,10 @@ def generate_with_kv_cache(model, input_signature, max_output_seq_length, eos_to start_idx = 0 end_idx = input_signature[0].shape[1] output_seq = input_signature[0].clone() - + isl = input_signature[0].shape[1] # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL - while end_idx < max_output_seq_length: + num_tokens_generated = 0 + while num_tokens_generated <= max_output_seq_length - isl: # end_idx < max_output_seq_length: input_signature_with_start_end_idx = input_signature + (start_idx, end_idx) logits_keys_values = model(*input_signature_with_start_end_idx) logits = logits_keys_values[0] @@ -85,7 +86,7 @@ def generate_with_kv_cache(model, input_signature, max_output_seq_length, eos_to output_seq = torch.cat([output_seq, next_tokens], dim=-1) start_idx = end_idx end_idx = start_idx + 1 - + num_tokens_generated += 1 return output_seq def time_generate( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 39d5ce7464..b95ae9e3cb 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -684,7 +684,11 @@ def compile( ) gm = exported_program.module() + exported_program.module().to("cpu") + torch.cuda.empty_cache() + import gc + gc.collect() logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py new file mode 100644 index 0000000000..9aac192316 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py @@ -0,0 +1,771 @@ +from __future__ import annotations + +import logging +from contextlib import nullcontext +from tempfile import tempdir +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import tensorrt as trt +import torch +import torch_tensorrt +from torch.nn import Module +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import Platform, dtype +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._utils import ( + _is_switch_required, + _select_rt_device, + multi_gpu_device_check, +) + +logger = logging.getLogger(__name__) + + +class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] + def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: + trt.IOutputAllocator.__init__(self) + self.buffers: Dict[str, torch.Tensor] = {} + self.shapes: Dict[str, Tuple[int, ...]] = {} + self.dtypes: Dict[str, torch.dtype] = output_dtypes + + def reallocate_output_async( + self, + tensor_name: str, + memory: int, + size: int, + alignment: int, + stream: torch.cuda.Stream, + ) -> Any: + shape = (size,) + if tensor_name not in self.buffers: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + else: + if self.buffers[tensor_name].shape != shape: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + return 
self.buffers[tensor_name].data_ptr() + + def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None: + self.shapes[tensor_name] = tuple(shape) + + +class TorchTRTRuntimeStates: + def __init__(self, new_cudagraphs: bool): + # Indicates whether CUDAGraphs were enabled in the previous execute_engine + self.old_cudagraphs = new_cudagraphs + # Indicates whether pre-allocated output was enabled in the previous execute_engine + self.old_pre_allocated_outputs = False + # Indicates whether context has changed + self.context_changed = False + + def set_runtime_states( + self, + new_cudagraphs: bool, + new_pre_allocated_output: bool, + shape_changed: bool, + ) -> Tuple[bool, bool, bool]: + # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs + # based on the current and previous states, as well as input shape has changed + need_cudagraphs_record = False + can_use_pre_allocated_outputs = False + need_cudagraphs_reset = False + + # CUDA Graph recording is needed if CUDA graphs is enabled and: + # - CUDA graphs were previously disabled + # - or the shape has changed + # - or the execution context has changed (e.g., weight streaming) + if new_cudagraphs and ( + not self.old_cudagraphs or shape_changed or self.context_changed + ): + need_cudagraphs_record = True + + # Pre-allocated output can be used when previous and current state are true without shape change + if ( + self.old_pre_allocated_outputs + and new_pre_allocated_output + and (not shape_changed) + ): + can_use_pre_allocated_outputs = True + + if not new_cudagraphs or shape_changed or self.context_changed: + need_cudagraphs_reset = True + + self.old_cudagraphs = new_cudagraphs + self.old_pre_allocated_outputs = new_pre_allocated_output + # reset flag + self.context_changed = False + + return ( + need_cudagraphs_record, + can_use_pre_allocated_outputs, + need_cudagraphs_reset, + ) + + +class PythonTorchTensorRTModule(Module): # type: ignore[misc] + """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. + + This module is backed by the Torch-TensorRT runtime and is only compatible with + FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment. + """ + + def __init__( + self, + serialized_engine: Optional[bytes] = None, + input_binding_names: Optional[List[str]] = None, + output_binding_names: Optional[List[str]] = None, + *, + name: str = "", + settings: CompilationSettings = CompilationSettings(), + weight_name_map: Optional[dict[Any, Any]] = None, + requires_output_allocator: bool = False, + ): + """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs + a PyTorch ``torch.nn.Module`` around it. 
Uses TensorRT Python APIs to run the engine + + Arguments: + serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray + input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules + output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned + + Keyword Arguments: + name (str): Name for module + settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name + requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) + + Example: + + .. code-block:: py + + trt_module = PythonTorchTensorRTModule( + engine_str, + input_binding_names=["x"], + output_binding_names=["output"], + name="my_module", + settings=CompilationSettings(device=torch.cuda.current_device) + ) + + """ + self.context: Any + super(PythonTorchTensorRTModule, self).__init__() + self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) + + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + + self.name = name + self._input_buffers: Dict[str, List[torch.Tensor]] = {} + self._output_buffers: Dict[str, List[torch.Tensor]] = {} + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self._caller_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: Optional[torch.cuda.Stream] = None + + # TODO: Make the below a Dictionary {shape: cudagraph} + self.shape_key_to_cudagraph: Dict[str, torch.cuda.CUDAGraph] = {} + + # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98 + # Unused currently - to be used by Dynamic Shape support implementation + self.memory_pool = None + + self.serialized_engine = serialized_engine + self.input_names = ( + input_binding_names if input_binding_names is not None else [] + ) + self.output_names = ( + output_binding_names if output_binding_names is not None else [] + ) + self.initialized = False + self.target_device_id = ( + settings.device.gpu_id + if settings.device is not None + else Device._current_device().gpu_id + ) + self.target_device_properties = torch.cuda.get_device_properties( + self.target_device_id + ) + self.profiling_enabled = settings.debug if settings.debug is not None else False + self.settings = settings + self.engine = None + self.weight_name_map = weight_name_map + self.target_platform = Platform.current_platform() + self.runtime_states = TorchTRTRuntimeStates( + torch_tensorrt.runtime.get_cudagraphs_mode() + ) + + self.cudagraphs_enabled = False + self.pre_allocated_outputs: List[torch.Tensor] = [] + self.use_pre_allocated_outputs = False + + self.requires_output_allocator = requires_output_allocator + self.output_allocator: Optional[DynamicOutputAllocator] = None + self.use_output_allocator_outputs = False + + if self.serialized_engine is not None and not self.settings.lazy_engine_init: + self.setup_engine() + + def get_streamable_device_memory_budget(self) -> Any: + return self.engine.streamable_weights_size + + def get_automatic_device_memory_budget(self) -> Any: + return self.engine.get_weight_streaming_automatic_budget() + + def get_device_memory_budget(self) -> Any: + return self.engine.weight_streaming_budget_v2 + + 
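A hedged usage sketch for the budget accessors around this point; it assumes a module instance trt_module of this class whose engine was built with enable_weight_streaming=True, and it is illustrative rather than a documented API contract:

    # Apply the automatically derived budget; passing a negative value would fall back to
    # the full streamable size, which effectively disables weight streaming.
    auto_budget = trt_module.get_automatic_device_memory_budget()
    applied = trt_module.set_device_memory_budget(auto_budget)
    assert applied == trt_module.get_device_memory_budget()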
def set_device_memory_budget(self, budget_bytes: int) -> int: + # Recreating the context because weight streaming budget cannot be modified while there are active context. + if self.context is not None: + del self.context + budget_bytes = self._set_device_memory_budget(budget_bytes) + self.context = self.engine.create_execution_context() + self.runtime_states.context_changed = True + return budget_bytes + + def _set_device_memory_budget(self, budget_bytes: int) -> int: + # Disable weight streaming for invalid budget size + if budget_bytes < 0: + budget_bytes = self.get_streamable_device_memory_budget() + self.engine.weight_streaming_budget_v2 = budget_bytes + if self.engine.weight_streaming_budget_v2 != budget_bytes: + logger.error(f"Failed to set weight streaming budget to {budget_bytes}") + budget_bytes = self.engine.weight_streaming_budget_v2 + if self.get_streamable_device_memory_budget() == budget_bytes: + logger.warning("Weight streaming is disabled") + + return budget_bytes + + def set_default_device_memory_budget(self) -> int: + budget_bytes = self.get_automatic_device_memory_budget() + # Set automatic weight streaming budget as default when context is created + logger.debug(f"Weight streaming budget set to {budget_bytes}B") + return self._set_device_memory_budget(budget_bytes) + + def setup_engine(self) -> None: + assert ( + self.target_platform == Platform.current_platform() + ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" + + self.initialized = True + runtime = trt.Runtime(TRT_LOGGER) + self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) + if self.settings.enable_weight_streaming: + self.set_default_device_memory_budget() + self.context = self.engine.create_execution_context() + assert self.engine.num_io_tensors == ( + len(self.input_names) + len(self.output_names) + ) + + self.input_dtypes = [ + dtype._from(self.engine.get_tensor_dtype(input_name)) + for input_name in self.input_names + ] + self.input_shapes = [ + self.engine.get_tensor_shape(input_name) for input_name in self.input_names + ] + self.output_dtypes = [ + dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype) + for output_name in self.output_names + ] + self.output_shapes = [ + self.engine.get_tensor_shape(output_name) + for output_name in self.output_names + ] + + if self.requires_output_allocator: + self.create_output_allocator() + + if torch_tensorrt.runtime.get_cudagraphs_mode(): + self.cudagraph = torch.cuda.CUDAGraph() + + def _check_initialized(self) -> None: + if not self.initialized: + raise RuntimeError("PythonTorchTensorRTModule is not initialized.") + + def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: + state_dict[prefix + "engine"] = self.serialized_engine + state_dict[prefix + "input_names"] = self.input_names + state_dict[prefix + "output_names"] = self.output_names + state_dict[prefix + "platform"] = self.target_platform + + def _load_from_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Any, + strict: Any, + missing_keys: Any, + unexpected_keys: Any, + error_msgs: Any, + ) -> None: + self.serialized_engine = state_dict[prefix + "engine"] + self.input_names = state_dict[prefix + "input_names"] + self.output_names = state_dict[prefix + "output_names"] + self.target_platform = state_dict[prefix + "platform"] + + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + 
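        # Rebuild the deserialized engine and a fresh execution context from the state that was just loaded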
self.setup_engine() + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state.pop("engine", None) + state.pop("context", None) + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + self.setup_engine() + + def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + result.__setstate__(self.__getstate__()) + return result + + def _reset_captured_graph(self, inputs_shape_key: str = None) -> None: + if inputs_shape_key in self.shape_key_to_cudagraph: + self.shape_key_to_cudagraph[inputs_shape_key].reset() + self.shape_key_to_cudagraph.pop(inputs_shape_key) + + def __del__(self) -> None: + self._reset_captured_graph() + + def setup_input_tensors( + self, + contiguous_inputs: List[torch.Tensor], + cudagraphs_enabled: bool, + need_cudagraphs_record: bool, + inputs_shape_key: str = None, + ) -> None: + for i, input_name in enumerate(self.input_names): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." + ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) + + assert ( + contiguous_inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." + + is_shape_tensor_input = self.engine.is_shape_inference_io(input_name) + if need_cudagraphs_record: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + if is_shape_tensor_input: + self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone() + else: + self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone() + + # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers + # as per TensorRT requirements + if is_shape_tensor_input: + # Shape tensor inputs are casted to int64 explicitly + # Currently Torch CPU pointers are not working; numpy pointers are used instead + # to refer to underlying memory + inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64) + inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() + # if cudagraphs_enabled: + # self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu) + # self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data) + # else: + self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data) + else: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) + if cudagraphs_enabled: + self._input_buffers[inputs_shape_key][i].copy_(contiguous_inputs[i]) + self.context.set_tensor_address( + input_name, self._input_buffers[inputs_shape_key][i].data_ptr() + ) + else: + self.context.set_tensor_address( + input_name, contiguous_inputs[i].data_ptr() + ) + + def create_output_tensors(self) -> List[torch.Tensor]: + # create output tensors + outputs: List[torch.Tensor] = [] + + for o, _ in enumerate(self.output_names): + output = torch.empty( + size=self.output_shapes[o], + dtype=self.output_dtypes[o], + device=torch.cuda.current_device(), + ) 
+ outputs.append(output) + return outputs + + def set_pre_allocated_outputs(self, enable: bool) -> None: + self.use_pre_allocated_outputs = enable + + def set_use_output_allocator(self, enable: bool) -> None: + self.use_output_allocator_outputs = enable + + def create_output_allocator(self) -> None: + if self.output_allocator is None: + output_dtypes_dict = {} + for o, output_name in enumerate(self.output_names): + output_dtypes_dict[output_name] = self.output_dtypes[o] + self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) + + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: + + def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: + # print(f"**************** first key cache shape: {inputs[1].shape}") + shape_changed, inputs_shape_key = self.validate_input_shapes(inputs) + ( + need_cudagraphs_record, + can_use_pre_allocated_outputs, + need_cudagraphs_reset, + ) = self.runtime_states.set_runtime_states( + self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed + ) + + if need_cudagraphs_reset: + self._reset_captured_graph(inputs_shape_key) + + if need_cudagraphs_record: + self._input_buffers[inputs_shape_key] = [None] * len(self.input_names) + self._output_buffers[inputs_shape_key] = [None] * len(self.output_names) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." + + self.setup_input_tensors( + contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key + ) + + if shape_changed: + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + if can_use_pre_allocated_outputs: + outputs = self.pre_allocated_outputs + else: + self.output_shapes = [ + tuple(self.context.get_tensor_shape(output_name)) + for output_name in self.output_names + ] + if DYNAMIC_DIM in self.output_shapes: + raise ValueError( + "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
+ ) + outputs = self.create_output_tensors() + + for o, output_name in enumerate(self.output_names): + if need_cudagraphs_record: + self._output_buffers[inputs_shape_key][o] = outputs[o].clone() + + if self.cudagraphs_enabled: + self.context.set_tensor_address( + output_name, self._output_buffers[inputs_shape_key][o].data_ptr() + ) + else: + self.context.set_tensor_address( + output_name, outputs[o].data_ptr() + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + if self.cudagraphs_enabled: + if need_cudagraphs_record: + + self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph() + + if self.profiling_enabled: + self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode() + + with torch.cuda.graph( + self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream + ): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) + + if self.profiling_enabled: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + self.shape_key_to_cudagraph[inputs_shape_key].debug_dump( + f"{tempdir}/{self.name}_cudagraph.dot" + ) + + self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore + + else: + self.context.execute_async_v3(self._engine_stream.cuda_stream) + + self._caller_stream.wait_stream(self._engine_stream) + + if self.use_pre_allocated_outputs: + self.pre_allocated_outputs = self.create_output_tensors() + + if self.cudagraphs_enabled: + for idx, o in enumerate(outputs): + o.copy_(self._output_buffers[inputs_shape_key][idx]) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: + assert ( + not torch_tensorrt.runtime.get_cudagraphs_mode() + ), "CUDA Graphs are not compatible with OutputAllocator." + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." 
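                # The output-allocator path is mutually exclusive with CUDA Graphs (asserted at the
                # top of this function), so inputs are bound directly to the user-provided tensors:
                # cudagraphs_enabled=False and need_cudagraphs_record=False skip the capture buffers.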
+ + self.setup_input_tensors(contiguous_inputs, False, False) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:SetupOutputAllocator" + ) + if self.profiling_enabled + else nullcontext() + ): + self.create_output_allocator() + # need to set output allocator every run + for output_name in self.output_names: + if not self.context.set_output_allocator( + output_name, self.output_allocator + ): + raise RuntimeError( + f"Failed to set output allocator for {output_name}" + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) # The OutputAllocator is called by execute_async_v3() + + self._caller_stream.wait_stream(self._engine_stream) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + outputs = [] + assert self.output_allocator is not None + for o, output_name in enumerate(self.output_names): + shape = self.output_allocator.shapes.get(output_name, None) + dtype = self.output_dtypes[o] + output = ( + self.output_allocator.buffers.get(output_name, None) + .clone() + .detach() + ) + prod = int(torch.prod(torch.tensor(shape))) + # When using the OutputAllocator, the allocated buffer might be larger than the size of the output, + # so we need to reshape the buffer to the output shape + output = output.reshape(-1).view(dtype)[:prod].reshape(shape) + outputs.append(output) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + + # Run forward function + contiguous_inputs: List[torch.Tensor] = [ + (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) + for i in inputs + ] + with ( + torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") + if self.profiling_enabled + else nullcontext() + ): + self._check_initialized() + + # If in safe mode, check at each iteration for whether a switch is required + if ( + torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + ): + curr_device_id = torch.cuda.current_device() + curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( + curr_device_id, + self.target_device_id, + curr_device_properties, + self.target_device_properties, + ): + device_id, _ = _select_rt_device( + curr_device_id, + self.target_device_id, + self.target_device_properties, + ) + + # Update current device + device = torch.device(device_id) + torch.cuda.set_device(device_id) + + contiguous_inputs = [ + tensor.to(device) for tensor in contiguous_inputs + ] + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + + if self.requires_output_allocator: # engine requires OA + if self.cudagraphs_enabled: + raise RuntimeError( + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. 
Please disable CUDA Graphs." + ) + logger.debug("Using the dynamic allocator runtime mode.") + return run_output_allocator() + else: + if self.use_output_allocator_outputs: # users call OA context manager + if self.cudagraphs_enabled: + raise RuntimeError( + "Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two." + ) + logger.debug("Using the dynamic allocator runtime mode.") + return run_output_allocator() + else: + logger.debug( + f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." + ) + return run_standard_execution() + + def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: + """ + Enable TensorRT profiling. After calling this function, TensorRT will report + time spent on each layer in stdout for each forward run. + """ + self._check_initialized() + + if not self.context.profiler: + self.context.profiler = trt.Profiler() if profiler is None else profiler + + self.profiling_enabled = True + + def disable_profiling(self) -> None: + """ + Disable TensorRT profiling. + """ + self._check_initialized() + torch.cuda.synchronize() + del self.context + self.context = self.engine.create_execution_context() + self.profiling_enabled = False + + def get_layer_info(self) -> str: + """ + Get layer info of the engine. Only support for TRT > 8.2. + """ + inspector = self.engine.create_engine_inspector() + engine_json: str = inspector.get_engine_information( + trt.LayerInformationFormat.JSON + ) + return engine_json + + def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: + """ + Validates the input shapes of the forward function has changed + """ + # Representation of input shapes to a given model + # Shapes are concatenated as so: + # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) + tensor_inputs = [ + t if isinstance(t, torch.Tensor) else torch.tensor(t) + for t in inputs + ] + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) + + # If the new shape key differs from the existing one, + # invalidate the old shape key and remove the CUDAGraph + if new_shape_key not in self.shape_key_to_cudagraph: + logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. 
A new CUDAGraph will be recorded with this input shape.") + # self.shape_key = new_shape_key + return True, new_shape_key + + return False, new_shape_key From 36426148b2828e7c14cdb5fd846f23032dbf90f9 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 14 May 2025 23:18:49 +0000 Subject: [PATCH 05/30] chore: updates --- examples/dynamo/llama3_trt.py | 27 ++++++++++++++++++++++----- examples/dynamo/utils.py | 13 +++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py index 857113ad03..30d086757d 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llama3_trt.py @@ -26,12 +26,24 @@ def get_model(args): with torch.no_grad(): - if args.model == "meta-llama/Llama-3.2-1B-Instruct": + if args.model == "meta-llama/Llama-2-7b-hf": model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, attn_implementation="sdpa", + # num_hidden_layers=1 + ) + .eval() + .cuda() + ) + elif args.model == "meta-llama/Llama-3.2-1B-Instruct": + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1 ) .eval() .cuda() @@ -94,7 +106,7 @@ def compile_torchtrt(model, input_ids, args): use_fp32_acc = False else: enabled_precisions = {torch.float32} - + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile( ep, @@ -114,12 +126,12 @@ def compile_torchtrt(model, input_ids, args): def print_outputs(backend_name, gen_tokens, tokenizer): - + print(f"========= {backend_name} =========") print( f"{backend_name} model generated text: ", tokenizer.decode(gen_tokens[0], skip_special_tokens=True), ) - print("=============================") + print("===================================") def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): """ @@ -260,7 +272,11 @@ def measure_perf(trt_model, input_signature, backend_name): import torch_tensorrt.extensions trt_model = compile_torchtrt(model, input_ids, args) - + pyt_logits = model.cuda()(input_ids.clone()) + trt_logits = trt_model(input_ids.clone()) + print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") + # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") + # breakpoint() if args.kv_cache: trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) if args.cudagraph: @@ -269,6 +285,7 @@ def measure_perf(trt_model, input_signature, backend_name): torch_tensorrt.runtime.set_cudagraphs_mode(True) trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + trt_gen_tokens = generate_with_kv_cache( trt_model, trt_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, ) diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 7379e09ba9..e4c4b8af6b 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -59,10 +59,10 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr next_tokens = torch.argmax(next_token_logits, dim=-1) input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) num_tokens_generated += 1 - # # TODO: Handle batch in this check + # TODO: Handle batch in this check if not benchmark and stopping_criteria(input_seq, logits).item(): break - + # breakpoint() return input_seq def generate_with_kv_cache(model, input_signature, max_output_seq_length, eos_token_id): @@ -74,19 +74,24 @@ def generate_with_kv_cache(model, 
input_signature, max_output_seq_length, eos_to output_seq = input_signature[0].clone() isl = input_signature[0].shape[1] # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL + logits_concat = [] num_tokens_generated = 0 - while num_tokens_generated <= max_output_seq_length - isl: # end_idx < max_output_seq_length: + while num_tokens_generated < max_output_seq_length - isl: # end_idx < max_output_seq_length: input_signature_with_start_end_idx = input_signature + (start_idx, end_idx) logits_keys_values = model(*input_signature_with_start_end_idx) + num_tokens_generated += 1 logits = logits_keys_values[0] + logits_concat.append(logits) kv_cache = logits_keys_values[1:] next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) input_signature = (next_tokens, *kv_cache) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) start_idx = end_idx end_idx = start_idx + 1 - num_tokens_generated += 1 + lkv = torch.cat(logits_concat, dim=1) + # breakpoint() return output_seq def time_generate( From 3688630a4ae73d7ede09729c844d5ff02684c06c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 16 May 2025 00:32:41 +0000 Subject: [PATCH 06/30] chore: updates --- examples/dynamo/llama3_trt.py | 10 ++-- examples/dynamo/test_sdpa.py | 101 ++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 5 deletions(-) create mode 100644 examples/dynamo/test_sdpa.py diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py index 30d086757d..a8056cf198 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llama3_trt.py @@ -17,7 +17,7 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList +from transformers import AutoModelForCausalLM, AutoTokenizer from contextlib import nullcontext from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache @@ -26,13 +26,13 @@ def get_model(args): with torch.no_grad(): - if args.model == "meta-llama/Llama-2-7b-hf": + if args.model == "meta-llama/Llama-2-7b-chat-hf": model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=1 + num_hidden_layers=1 ) .eval() .cuda() @@ -271,12 +271,12 @@ def measure_perf(trt_model, input_signature, backend_name): # This import is required to register static/dynamic KV cache transformations as lowering passes import torch_tensorrt.extensions - trt_model = compile_torchtrt(model, input_ids, args) pyt_logits = model.cuda()(input_ids.clone()) + trt_model = compile_torchtrt(model, input_ids, args) trt_logits = trt_model(input_ids.clone()) print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") - # breakpoint() + breakpoint() if args.kv_cache: trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) if args.cudagraph: diff --git a/examples/dynamo/test_sdpa.py b/examples/dynamo/test_sdpa.py new file mode 100644 index 0000000000..f93a08509a --- /dev/null +++ b/examples/dynamo/test_sdpa.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from 
transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse + +# llama2_model_name = "meta-llama/Llama-2-7b-hf" +llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" +llama_model = AutoModelForCausalLM.from_pretrained( + llama3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval() +LLAMA_CONFIG = llama_model.config + +def test_llama_attention(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention( + config=self.config, + layer_idx=0 + ) + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) + return attn_output + + DTYPE = torch.float32 + model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with torch_tensorrt.logging.debug(): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions={torch.float32}, + debug=True) + trt_output = trt_model(hidden_states, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + +def test_llama_decoder(args): + class LlamaDecoder(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder_layer = LlamaDecoderLayer( + config=self.config, + layer_idx=0 + ) + def forward(self, hidden_states, position_embeddings): + decoder_output = self.decoder_layer(hidden_states, position_embeddings=position_embeddings) + return decoder_output[0] + + DTYPE = torch.float32 + model = LlamaDecoder().eval().cuda().to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + test_llama_attention(args) + # test_llama_decoder(args) \ No newline at end of file From c9f5f2735aeff4333438399f8ec5faea994fc471 Mon Sep 17 00:00:00 2001 From: Dheeraj 
Peri Date: Tue, 20 May 2025 02:47:52 +0000 Subject: [PATCH 07/30] chore: refactor updates --- examples/dynamo/llama3_trt.py | 134 +++++---- examples/dynamo/test_if.py | 86 ------ examples/dynamo/test_sdpa.py | 31 ++- examples/dynamo/utils.py | 48 +++- py/torch_tensorrt/dynamo/_compiler.py | 4 +- .../dynamo/conversion/aten_ops_converters.py | 32 +-- .../dynamo/conversion/impl/__init__.py | 1 - .../dynamo/conversion/impl/attention.py | 165 ----------- .../lowering/passes/_aten_lowering_pass.py | 3 +- .../lowering/passes/constant_folding.py | 1 - .../lower_scaled_dot_product_attention.py | 7 +- py/torch_tensorrt/extensions/__init__.py | 1 - py/torch_tensorrt/extensions/hf/__init__.py | 2 - .../extensions/hf/dynamic_cache.py | 249 ----------------- .../extensions/hf/static_cache.py | 258 ------------------ py/torch_tensorrt/extensions/hf/utils.py | 152 ----------- 16 files changed, 125 insertions(+), 1049 deletions(-) delete mode 100644 examples/dynamo/test_if.py delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/attention.py delete mode 100644 py/torch_tensorrt/extensions/__init__.py delete mode 100644 py/torch_tensorrt/extensions/hf/__init__.py delete mode 100644 py/torch_tensorrt/extensions/hf/dynamic_cache.py delete mode 100644 py/torch_tensorrt/extensions/hf/static_cache.py delete mode 100644 py/torch_tensorrt/extensions/hf/utils.py diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py index a8056cf198..8ade81ccaa 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llama3_trt.py @@ -43,7 +43,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - num_hidden_layers=1 + # num_hidden_layers=1 ) .eval() .cuda() @@ -133,28 +133,7 @@ def print_outputs(backend_name, gen_tokens, tokenizer): ) print("===================================") -def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): - """ - Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. - - This function identifies placeholder nodes in the graph that represent KV cache tensors, - and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. - - Args: - model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders - - Returns: - tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph - """ - # placeholder nodes are expected to be in the following order: - # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx - placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] - kv_cache_inputs = placeholder_nodes[1:-2] - zeroed_kv_cache_inputs = [] - for input in kv_cache_inputs: - zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=DEVICE)) - return tuple(zeroed_kv_cache_inputs) def measure_perf(trt_model, input_signature, backend_name): # Measure average time for 10 iterations @@ -207,7 +186,7 @@ def measure_perf(trt_model, input_signature, backend_name): "--min_block_size", type=int, default=1, help="no. of iterations to run" ) arg_parser.add_argument( - "--max_tokens", type=int, default=128, help="no. of iterations to run" + "--max_tokens", type=int, default=128, help="no. 
of max tokens to be generated" ) arg_parser.add_argument( "--enable_pytorch_run", @@ -229,6 +208,11 @@ def measure_perf(trt_model, input_signature, backend_name): action="store_true", help="Enable debug (default: False)" ) + arg_parser.add_argument( + "--benchmark", + action="store_true", + help="Enable benchmark (default: False)" + ) args = arg_parser.parse_args() with torch.inference_mode(): model = get_model(args) @@ -253,75 +237,79 @@ def measure_perf(trt_model, input_signature, backend_name): pyt_gen_tokens = generate( model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id ) - - pyt_timings = time_generate( - generate, - model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - pyt_stats = recordStats( - "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None - ) + + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None + ) # TRT + from lower_sdpa import * if args.kv_cache: # This import is required to register static/dynamic KV cache transformations as lowering passes - import torch_tensorrt.extensions - - pyt_logits = model.cuda()(input_ids.clone()) - trt_model = compile_torchtrt(model, input_ids, args) - trt_logits = trt_model(input_ids.clone()) - print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") - # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") - breakpoint() + from static_cache import * + trt_model = compile_torchtrt(model, input_ids, args) + else: + # pyt_logits = model.cuda()(input_ids.clone()) + trt_model = compile_torchtrt(model, input_ids, args) + # trt_logits = trt_model(input_ids.clone(), True) + # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") + # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") if args.kv_cache: - trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
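            # Note: from this patch on, generate_with_kv_cache is called with input_ids directly;
            # the explicit trt_input_signature construction is dropped here and
            # get_zeroed_kv_cache_inputs moves into examples/dynamo/utils.py later in this patch.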
- trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) torch_tensorrt.runtime.set_cudagraphs_mode(True) - - trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - + trt_gen_tokens = generate_with_kv_cache( - trt_model, trt_input_signature, MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, ) - trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - trt_timings = time_generate( - generate_with_kv_cache, - trt_model, - trt_input_signature, - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) + if args.benchmark: + trt_timings = time_generate( + generate_with_kv_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) else: trt_gen_tokens = generate( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, ) - trt_timings = time_generate( - generate, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = recordStats( + "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None ) - trt_stats = recordStats( - "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None - ) if args.enable_pytorch_run: print_outputs("PyTorch", pyt_gen_tokens, tokenizer) print_outputs("TensorRT", trt_gen_tokens, tokenizer) - if args.enable_pytorch_run: - print("=========PyTorch PERFORMANCE============ \n") - print(pyt_stats) + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) print("===================== \n") - print("=========TensorRT PERFORMANCE============ \n") - print(trt_stats) + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/examples/dynamo/test_if.py b/examples/dynamo/test_if.py deleted file mode 100644 index ca573f330c..0000000000 --- a/examples/dynamo/test_if.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.nn as nn -from torch.export import export - -class ConditionalModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, q, k, v, k1, v1, flag): - def true_fn(q, k, v, k1, v1): - k_new = torch.cat((k, k1), dim=2) - v_new = torch.cat((v, v1), dim=2) - return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) - - def false_fn(q, k, v, k1, v1): - return torch._C._nn.scaled_dot_product_attention(q, k, v) - - out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) - - return 2 * out - -class ConditionalModel2(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, q, k, v, k1, v1, start_idx, end_idx): - - new_k1 = torch.cat((k1[:, :, :start_idx, :], k, k1[:, :, end_idx:, :]), dim=2) - new_v1 = torch.cat((v1[:, :, :start_idx, :], v, v1[:, :, end_idx:, :]), dim=2) - out = torch._C._nn.scaled_dot_product_attention(q, new_k1[:,:,:end_idx,:], new_v1[:,:,:end_idx,:]) - - return out, new_k1, new_v1 - - -def main(): - # Create model - model = ConditionalModel2() - model.eval() # Set to evaluation mode - - # Create example inputs - q = torch.randn(1, 32, 2048, 
64).cuda() - k = torch.randn(1, 32, 2048, 64).cuda() - v = torch.randn(1, 32, 2048, 64).cuda() - k1 = torch.zeros(1, 32, 2176, 64).cuda() - v1 = torch.zeros(1, 32, 2176, 64).cuda() - # example_flag = torch.tensor(True) - start_idx = 0 - end_idx = 2048 - out_pyt = model(q, k, v, k1, v1, start_idx, end_idx) - out_pyt2 = model(q, k, v, k1, v1, 17, 18) - - exported_program = export( - model, - args=(q, k, v, k1, v1, start_idx, end_idx), - dynamic_shapes=(None, None, None, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC), - strict=False - ) - import torch_tensorrt - with torch_tensorrt.logging.debug(): - trt_model = torch_tensorrt.dynamo.compile( - exported_program, - inputs=[q, k, v, k1, v1, start_idx, end_idx], - enabled_precisions={torch.float32}, - # truncate_double=True, - disable_tf32=True, - use_python_runtime=True, - debug=True, - min_block_size=1, - ) - - gm = exported_program.module() - breakpoint() - out_ep = gm(q, k, v, k1, v1, start_idx, end_idx) - out_ep2 = gm(q, k, v, k1, v1, 2048, 2049) - out_trt = trt_model(q, k, v, k1, v1, start_idx, end_idx) - out_trt2 = trt_model(q, k, v, k1, v1, 2048, 2049) - # breakpoint() - # Print the graph - print("\nExported Graph:") - print(exported_program.graph_module.graph) - # breakpoint() - # print("done") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/dynamo/test_sdpa.py b/examples/dynamo/test_sdpa.py index f93a08509a..c8e3811925 100644 --- a/examples/dynamo/test_sdpa.py +++ b/examples/dynamo/test_sdpa.py @@ -16,7 +16,7 @@ use_cache=False, attn_implementation="sdpa", num_hidden_layers=1, - ).eval() + ).eval().cuda() LLAMA_CONFIG = llama_model.config def test_llama_attention(args): @@ -33,24 +33,33 @@ def forward(self, hidden_states, position_embeddings): return attn_output DTYPE = torch.float32 - model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + model = llama_model.model.layers[0].self_attn.to(DTYPE) # llama3 - hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings) + # hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + # position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + hidden_states = torch.load("hidden_states.pt") + position_embeddings = torch.load("position_embeddings.pt") + # breakpoint() + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) with torch_tensorrt.logging.debug(): trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings], + inputs=[hidden_states, position_embeddings, None], enabled_precisions={torch.float32}, + disable_tf32=True, debug=True) - trt_output = trt_model(hidden_states, position_embeddings) + trt_output = trt_model(hidden_states, position_embeddings, None) + breakpoint() + if isinstance(pyt_output, tuple): + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + else: + 
print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") def test_llama_decoder(args): class LlamaDecoder(nn.Module): diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index e4c4b8af6b..26c7c4b605 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -38,6 +38,30 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): return ep +def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids and is_causal. The last two inputs are start_idx and end_idx. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-2] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) + + return tuple(zeroed_kv_cache_inputs) + def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): """ @@ -62,36 +86,36 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr # TODO: Handle batch in this check if not benchmark and stopping_criteria(input_seq, logits).item(): break - # breakpoint() + return input_seq -def generate_with_kv_cache(model, input_signature, max_output_seq_length, eos_token_id): +def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id): """ Greedy decoding of the model with KV cache. 
""" start_idx = 0 - end_idx = input_signature[0].shape[1] - output_seq = input_signature[0].clone() - isl = input_signature[0].shape[1] + end_idx = input_seq.shape[1] + output_seq = input_seq.clone() # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL logits_concat = [] num_tokens_generated = 0 - while num_tokens_generated < max_output_seq_length - isl: # end_idx < max_output_seq_length: - input_signature_with_start_end_idx = input_signature + (start_idx, end_idx) - logits_keys_values = model(*input_signature_with_start_end_idx) + + kv_cache = get_zeroed_kv_cache_inputs(model) + while end_idx < max_output_seq_length: + is_causal = True if input_seq.shape[1] > 1 else False + input_signature = (input_seq, is_causal, *kv_cache, start_idx, end_idx) + logits_keys_values = model(*input_signature) num_tokens_generated += 1 logits = logits_keys_values[0] logits_concat.append(logits) kv_cache = logits_keys_values[1:] next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) - input_signature = (next_tokens, *kv_cache) - output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens start_idx = end_idx end_idx = start_idx + 1 - lkv = torch.cat(logits_concat, dim=1) - # breakpoint() + return output_seq def time_generate( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index b95ae9e3cb..15acf99cb0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -684,7 +684,6 @@ def compile( ) gm = exported_program.module() - exported_program.module().to("cpu") torch.cuda.empty_cache() import gc @@ -797,7 +796,6 @@ def preserve_module_specs(in_spec, out_spec, target_module): return target_module - # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -1178,7 +1176,7 @@ def convert_exported_program_to_serialized_trt_engine( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, } - + settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 05b4582191..4046f5c54d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1896,37 +1896,7 @@ def aten_ops_minimum( args[1], ) -def attention_validator( - node: Node, settings: Optional[CompilationSettings] = None -) -> bool: - # Currently, `attn_mask` is not supported - return args_bounds_check(node.args, 3) is None - -@dynamo_tensorrt_converter( - torch.nn.functional.scaled_dot_product_attention, - capability_validator=attention_validator, - supports_dynamic_shapes=True, -) -def tensorrt_scaled_dot_product_attention( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.attention.scaled_dot_product_attention( - ctx, - target, - SourceIR.TORCHTRT_LOWERED, - name, - args[0], - args[1], - args[2], - args_bounds_check(args, 5, False), - kwargs.get("scale", None), - ) - - +@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, 
supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True) def aten_ops_sub( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index a8b0fbe284..df580b1516 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -1,6 +1,5 @@ from torch_tensorrt.dynamo.conversion.impl import ( activation, - attention, addmm, arange, cast, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py deleted file mode 100644 index 71dfb5f818..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py +++ /dev/null @@ -1,165 +0,0 @@ -import math -from typing import Optional, Union - -import numpy as np -import tensorrt as trt -from torch.fx.node import Target -from torch_tensorrt._enums import dtype -from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, - cast_trt_tensor, - get_trt_tensor, -) -from torch_tensorrt.fx.types import TRTTensor - - -def tril( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, -) -> TRTTensor: - # the lower triangle of the tensor means the rows greater than and equal to the cols - row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) - col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) - rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 - ) - # get the rows - row_tensor = impl.elementwise.trunc_div( - ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col - ) - # get the cols - col_tensor = impl.elementwise.fmod( - ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col - ) - cond = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_tensor, col_tensor - ) - return impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", cond, [row, col] - ) - - -def scaled_dot_product_attention( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - query: TRTTensor, - key: TRTTensor, - value: TRTTensor, - is_causal: bool, - scale: Optional[float], -) -> TRTTensor: - # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) - if scale is None: - scale = query.shape[-1] - if scale < 0: - # dynamic shape - scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) - sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) - else: - # static shape - sqrt_scaled = math.sqrt(scale) - scaled = impl.elementwise.div( - ctx, - target, - source_ir, - name + "_scale", - mm, - sqrt_scaled, - ) - else: - scaled = impl.elementwise.mul( - ctx, - target, - source_ir, - name + "_scale", - mm, - scale, - ) - - if is_causal: - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) - temp_mask = 
np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, -2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) - - LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) - - # this is to generate a tensor which has shape (L, S), type is int32 - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 - ) - shape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] - ) - - # since we want our attn_bias to be in float32, so cast it to float32 - shape_tensor = cast_trt_tensor( - ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir - ) - - # initialize the attn_bias as the zeros tensor - attn_bias = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 - ) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - inf_tensor = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") - ) - cond = impl.elementwise.eq( - ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) - ) - # mask out the certain part of the attn_bias - attn_bias = impl.condition.select( - ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond - ) - - scaled = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias - ) - - softmax = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", scaled, -1, False - ) - out = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_out", - softmax, - value, - ) - - return out \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index e47ecb6191..6e2019ad71 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -26,8 +26,7 @@ replace_max_pool_with_indices, remove_assert_nodes, accumulate_fp32_matmul, - lower_scaled_dot_product_attention, - remove_num_users_is_0_nodes, + # remove_num_users_is_0_nodes, ] if not is_tegra_platform(): diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..172d902a40 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -55,7 +55,6 @@ def constant_fold( del cf logger.debug(f"Graph after constant folding:\n{gm.graph}") - return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py index ee7651cb8c..89558acade 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -25,6 +25,8 @@ def lower_scaled_dot_product_attention( """ original_fns, replacement = 
scaled_dot_product_attention_replacement() replaced_nodes = [] + sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default] + breakpoint() # For each original function, search for it in the graph and replace for original in original_fns: replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( @@ -33,7 +35,7 @@ def lower_scaled_dot_product_attention( replacement, ignore_literals=True, ) - + breakpoint() if replaced_nodes: # Repair instances which use the kwargs field (specifically the "scale" kwarg) # Also repair instances which specified the is_causal or attn_bias fields @@ -69,8 +71,9 @@ def lower_scaled_dot_product_attention( # Set default args in new node: # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False + breakpoint() new_attention_node.args = new_attention_node.args + (None, 0.0, False) - + breakpoint() # The `is_causal` argument was specified if ( ( diff --git a/py/torch_tensorrt/extensions/__init__.py b/py/torch_tensorrt/extensions/__init__.py deleted file mode 100644 index 80112f1526..0000000000 --- a/py/torch_tensorrt/extensions/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import hf \ No newline at end of file diff --git a/py/torch_tensorrt/extensions/hf/__init__.py b/py/torch_tensorrt/extensions/hf/__init__.py deleted file mode 100644 index 19d2a493ea..0000000000 --- a/py/torch_tensorrt/extensions/hf/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .static_cache import * -# from .dynamic_cache import * \ No newline at end of file diff --git a/py/torch_tensorrt/extensions/hf/dynamic_cache.py b/py/torch_tensorrt/extensions/hf/dynamic_cache.py deleted file mode 100644 index b7348157ec..0000000000 --- a/py/torch_tensorrt/extensions/hf/dynamic_cache.py +++ /dev/null @@ -1,249 +0,0 @@ -import logging -from typing import Dict, List, Tuple, Union, Sequence, Any - -import torch -from torch.fx.node import Target - -import torch_tensorrt -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -from .utils import add_graph_input, create_random_output_tensors, get_kv_nodes -import tensorrt -import torch.utils._pytree as pytree -logger = logging.getLogger(__name__) - -@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) -def cond_converter( - ctx: torch_tensorrt.dynamo.conversion.ConversionContext, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: str, -) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: - """ - Converter for torch.ops.higher_order.cond operation to TensorRT. - - This function handles the conversion of PyTorch's conditional operation to TensorRT. - The conditional operation selects between two tensors based on a boolean predicate. 
- - Args: - ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context - target (Target): The target operation to convert - args (Tuple[Argument, ...]): The arguments to the operation - kwargs (Dict[str, Argument]): The keyword arguments to the operation - name (str): The name to give to the TensorRT layer - - Returns: - Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) - """ - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = args[0], args[1], args[2] - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - output = output_layer.get_output(0) - - return output - -def add_kv_as_outputs(gm): - """ - Modifies the graph to add query, key, and value tensors as outputs. - - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. - """ - # list of MHA kernels we would want to detect and replace - mha_ops = { - torch._C._nn.scaled_dot_product_attention, - } - - # Find all SDPA nodes in the graph - mha_nodes = [] - for node in gm.graph.nodes: - if is_op(node, mha_ops): - mha_nodes.append(node) - - # Iterate through each MHA node to extract shape information - for mha_node in mha_nodes: - if "val" in mha_node.meta and len(mha_node.args) >= 3: - # Get the input nodes (query, key, value) - q_node, k_node, v_node = mha_node.args[:3] - - # Add the copy nodes as outputs to the graph - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + ((k_node, v_node),) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs, (k_node, v_node)) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - - - -def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors and index parameters as inputs to the graph. 
- - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx") - end_idx_input = add_graph_input(gm, "end_idx") - return kv_inputs, start_idx_input, end_idx_input - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Insert slicing operations before each scaled_dot_product_attention operation. - """ - pass - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - sdpa_nodes.append(node) - - for idx, sdpa_node in enumerate(sdpa_nodes): - - -def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Insert a torch.cond operation before each scaled_dot_product_attention operation. 
- - Args: - gm: The FX GraphModule to modify - - Returns: - The modified GraphModule - """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - sdpa_nodes.append(node) - - # For each SDPA node, insert a torch.cond operation before it - for idx, sdpa_node in enumerate(sdpa_nodes): - - with gm.graph.inserting_before(sdpa_node): - pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) - q_node, k_node, v_node = sdpa_node.args[:3] - incoming_key, incoming_value = incoming_keys_values[idx] - # Create nodes for concatenating k with incoming_key and v with incoming_value - concatenated_k_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([k_node, incoming_key], 2), # Concatenate along sequence length dimension - kwargs={} - ) - concatenated_v_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([v_node, incoming_value], 2), # Concatenate along sequence length dimension - kwargs={} - ) - - # Create the torch.cond node - cond_k_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(pred_node, concatenated_k_node, k_node), - ) - - cond_v_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(pred_node, concatenated_v_node, v_node), - ) - - sdpa_node.args = (q_node, cond_k_node, cond_v_node) - - return gm - - - -@_aten_lowering_pass -def insert_dynamic_kv_cache( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert FlashInfer MHA + KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - - # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) - - # Call the function to add QKV as outputs - logits_keys_values = add_kv_as_outputs(gm, start_idx_input, end_idx_input) - - gm = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - # gm = insert_torch_cond_before_sdpa(gm, kv_inputs) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm - - diff --git a/py/torch_tensorrt/extensions/hf/static_cache.py b/py/torch_tensorrt/extensions/hf/static_cache.py deleted file mode 100644 index 55e10d0377..0000000000 --- a/py/torch_tensorrt/extensions/hf/static_cache.py +++ /dev/null @@ -1,258 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch.fx import Node - -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -import torch.utils._pytree as pytree -from .utils import add_graph_input, create_random_output_tensors, get_kv_nodes -logger = logging.getLogger(__name__) - - -def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Modifies the graph to add query, key, and value tensors as outputs. 
- - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. - """ - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + tuple(kv_cache_for_graph) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - -def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors and index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - - # Get input_ids metadata from the graph. 
The first placeholder node in the graph is the input_ids node - # find the first placeholder node, then pull out meta["val"] - placeholder_node = next(node for node in gm.graph.nodes if node.op == "placeholder") - input_ids_meta = placeholder_node.meta["val"] - seq_len = input_ids_meta.shape[1] - - # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive - shape_env = seq_len.node.shape_env - start_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(start_idx_unbacked_symint >= 0) - torch._check(start_idx_unbacked_symint <= seq_len) - - end_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(end_idx_unbacked_symint >= 0) - torch._check(end_idx_unbacked_symint <= seq_len) - # Set the symbolic ints as the metadata for start_idx and end_idx inputs - start_idx_input.meta["val"] = start_idx_unbacked_symint - end_idx_input.meta["val"] = end_idx_unbacked_symint - - return kv_inputs, start_idx_input, end_idx_input - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): - """ - Insert slicing operations before each scaled_dot_product_attention operation. - """ - pass - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - sdpa_nodes.append(node) - kv_cache_for_graph = [] - for idx, sdpa_node in enumerate(sdpa_nodes): - q_node, k_node, v_node = sdpa_node.args[:3] - incoming_key, incoming_value = incoming_keys_values[idx] - kv_cache_for_sdpa_node = [] - new_keys_values = [] - for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): - # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim - with gm.graph.inserting_before(sdpa_node): - slice_1 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_2 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_1, 1), - kwargs={} - ) - slice_3 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_2, 2, None, start_idx_input), - kwargs={} - ) - slice_4 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_3, 3), - kwargs={} - ) - # =============================================== # - # Create a slice node for key_cache[:,:, end_idx:,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim - slice_5 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_6 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_5, 1), - kwargs={} - ) - slice_7 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_6, 2, end_idx_input), - kwargs={} - ) - slice_8 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_7, 3), - kwargs={} - ) - # =============================================== # - # Concatenate the sliced tensors to build KV cache - cat = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([slice_4, current_key_or_value_node, slice_8], 2), - kwargs={} - ) - # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph - cat.meta.update(key_or_value.meta) - kv_cache_for_sdpa_node.append(cat) - # =============================================== # - # Get the current key and value by indexing the KV cache - slice_9 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(cat,), - kwargs={} - ) - slice_10 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_9, 1), - kwargs={} - ) - slice_11 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_10, 2, None, end_idx_input), - kwargs={} - ) - slice_12 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_11, 3), - kwargs={} - ) - new_keys_values.append(slice_12) - - kv_cache_for_graph.extend(kv_cache_for_sdpa_node) - sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) - - return gm, kv_cache_for_graph - - -@_aten_lowering_pass -def insert_kv_cache( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) - - # Build and update the KV cache using computed KV inputs for current token and - # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - - # Call the function to add QKV as outputs - logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) - - gm = clean_up_graph_after_modifications(gm) - new_output_tensors = create_random_output_tensors(logits_keys_values) - - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm - - diff --git a/py/torch_tensorrt/extensions/hf/utils.py b/py/torch_tensorrt/extensions/hf/utils.py deleted file mode 100644 index 7a17ad7e65..0000000000 --- a/py/torch_tensorrt/extensions/hf/utils.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -from torch.fx import Graph, GraphModule, Node -from typing import Optional, Union, Iterable, List, Tuple -from torch._ops import OpOverloadPacket -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.fx.passes.shape_prop import _extract_tensor_metadata -from torch.utils._pytree import _LEAF_SPEC -from torch._export.utils import _detect_fake_mode_from_gm - -def get_kv_nodes(gm): - """ - 
Get the key and value nodes from the graph. - """ - kv_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - q_node, k_node, v_node = node.args[:3] - kv_nodes.append((k_node, v_node)) - return kv_nodes - -def get_random_tensor_from_node(node: Node) -> torch.Tensor: - """ - Creates a random tensor based on the shape information in a node's metadata. - For symbolic dimensions, extracts the maximum value from the shape environment. - - Args: - node: A torch.fx.Node object with metadata containing tensor information - - Returns: - A random tensor with shape matching the node's metadata, or None if no valid - tensor information is found - """ - if "val" not in node.meta: - raise ValueError(f"No tensor information found in node metadata for node: {node}") - - fake_tensor = node.meta["val"] - shape = [] - - # Iterate through each dimension and handle symbolic dimensions - for dim in fake_tensor.shape: - if isinstance(dim, torch.SymInt): - # Extract the maximum value from the shape environment - max_val = dim.node.hint - shape.append(max_val) - else: - shape.append(dim) - - # Create a random tensor with the determined shape - dtype = fake_tensor.dtype - device = fake_tensor.device - random_tensor = torch.rand(shape, dtype=dtype, device=device) - - return random_tensor - -def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: - """ - Creates random tensors based on the shape information in node metadata. - For symbolic dimensions, extracts the maximum value from the shape environment. - - Args: - nodes: List of torch.fx.Node objects with metadata - - Returns: - List of random tensors with shapes matching the nodes' metadata - """ - random_tensors = [] - - for node in nodes: - if isinstance(node, Node): - node_tensor = get_random_tensor_from_node(node) - elif isinstance(node, tuple): - node_tensor_list = [] - for n in node: - random_tensor = get_random_tensor_from_node(n) - node_tensor_list.append(random_tensor) - node_tensor = tuple(node_tensor_list) - - random_tensors.append(node_tensor) - - return random_tensors - -def add_graph_input( - gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None -) -> Node: - """Add a graph input to the given GraphModule and return the newly created node. - - NOTE: function does NOT do any graph canonicalization. This is left to the user! - - Args: - gm (GraphModule): The GraphModule to add the input to. - name (str): The name of the input. - val (torch.Tensor): An example tensor to use for the input. - dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] - """ - # check that no dynamic shape is provided... 
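# Illustrative usage of add_graph_input as exercised by the KV-cache passes in
# this patch (a sketch only: it assumes `gm` is an exported torch.fx.GraphModule
# whose pytree input spec is a flat tuple of args and that add_graph_input is
# imported from this utils module; the placeholder name and shape below are
# made up for the example).
import torch

k_example = torch.zeros(1, 32, 2176, 64, device="cuda")  # max-shape static KV buffer
k_input = add_graph_input(gm, "k_cache_0_k_input", k_example)
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
# Each returned Node is a new placeholder appended after the existing inputs,
# with fake-tensor metadata ("val"/"tensor_meta") derived from the example value.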
- if dynamic_shape: - raise NotImplementedError("Dynamic shape not supported for adding graph inputs") - - # extract graph and input spec - graph: Graph = gm.graph - - in_spec = graph._codegen.pytree_info.in_spec - in_spec_for_args = in_spec.children_specs[0] - orig_args = graph._codegen.pytree_info.orig_args - assert in_spec_for_args.type is tuple - - # insert input node after currently last input node - node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] - with graph.inserting_after(node_last_input): - in_node = graph.placeholder(name) - in_spec_for_args.children_specs.append(_LEAF_SPEC) - orig_args.append(f"arg_{name}") - - # update pytree info recursively with __post_init__ starting at leaves - def call_post_init(spec): - for child_spec in spec.children_specs: - call_post_init(child_spec) - spec.__post_init__() - - call_post_init(in_spec) - - # set fake tensor information if all required information is available - fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) - if fake_mode and val is not None and isinstance(val, torch.Tensor): - if isinstance(val, FakeTensor): - fake_tensor = val - else: - fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) - in_node.meta["val"] = fake_tensor - in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) - - # return new node... - return in_node - -def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: - """Check if the node is a call to one of the ops.""" - if node.op != "call_function": - return False - # check if it's a single op that's provided - if isinstance(ops, OpOverloadPacket): - ops = [ops] - - # check if it's the op itself instead of an overload - if any(node.target == op for op in ops): - return True - - return False - -def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: - input_nodes: List[Node] = graph.find_nodes(op="placeholder") - output_nodes: List[Node] = graph.find_nodes(op="output") - return (input_nodes, output_nodes) \ No newline at end of file From 0cb0dccef99899e4e2bf7e9187d1aa3e5415a2c0 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 20 May 2025 02:50:13 +0000 Subject: [PATCH 08/30] chore: refactor updates --- examples/dynamo/dynamic_cache.py | 249 +++++++++++++++++++++++++ examples/dynamo/lower_sdpa.py | 67 +++++++ examples/dynamo/sdpa_converter.py | 181 ++++++++++++++++++ examples/dynamo/static_cache.py | 262 +++++++++++++++++++++++++++ examples/dynamo/test_static_cache.py | 170 +++++++++++++++++ 5 files changed, 929 insertions(+) create mode 100644 examples/dynamo/dynamic_cache.py create mode 100644 examples/dynamo/lower_sdpa.py create mode 100644 examples/dynamo/sdpa_converter.py create mode 100644 examples/dynamo/static_cache.py create mode 100644 examples/dynamo/test_static_cache.py diff --git a/examples/dynamo/dynamic_cache.py b/examples/dynamo/dynamic_cache.py new file mode 100644 index 0000000000..c678bac454 --- /dev/null +++ b/examples/dynamo/dynamic_cache.py @@ -0,0 +1,249 @@ +import logging +from typing import Dict, List, Tuple, Union, Sequence, Any + +import torch +from torch.fx.node import Target + +import torch_tensorrt +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +from .cache_utils import 
add_graph_input, create_random_output_tensors, get_kv_nodes +import tensorrt +import torch.utils._pytree as pytree +logger = logging.getLogger(__name__) + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. + + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output + +def add_kv_as_outputs(gm): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + # list of MHA kernels we would want to detect and replace + mha_ops = { + torch._C._nn.scaled_dot_product_attention, + } + + # Find all SDPA nodes in the graph + mha_nodes = [] + for node in gm.graph.nodes: + if is_op(node, mha_ops): + mha_nodes.append(node) + + # Iterate through each MHA node to extract shape information + for mha_node in mha_nodes: + if "val" in mha_node.meta and len(mha_node.args) >= 3: + # Get the input nodes (query, key, value) + q_node, k_node, v_node = mha_node.args[:3] + + # Add the copy nodes as outputs to the graph + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + ((k_node, v_node),) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs, (k_node, v_node)) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + + + +def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors and index parameters as inputs to the graph. 
+ + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx") + end_idx_input = add_graph_input(gm, "end_idx") + return kv_inputs, start_idx_input, end_idx_input + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + pass + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + + for idx, sdpa_node in enumerate(sdpa_nodes): + + +def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Insert a torch.cond operation before each scaled_dot_product_attention operation. 
+ + Args: + gm: The FX GraphModule to modify + + Returns: + The modified GraphModule + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + + # For each SDPA node, insert a torch.cond operation before it + for idx, sdpa_node in enumerate(sdpa_nodes): + + with gm.graph.inserting_before(sdpa_node): + pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + # Create nodes for concatenating k with incoming_key and v with incoming_value + concatenated_k_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([k_node, incoming_key], 2), # Concatenate along sequence length dimension + kwargs={} + ) + concatenated_v_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([v_node, incoming_value], 2), # Concatenate along sequence length dimension + kwargs={} + ) + + # Create the torch.cond node + cond_k_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(pred_node, concatenated_k_node, k_node), + ) + + cond_v_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(pred_node, concatenated_v_node, v_node), + ) + + sdpa_node.args = (q_node, cond_k_node, cond_v_node) + + return gm + + + +@_aten_lowering_pass +def insert_dynamic_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert FlashInfer MHA + KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + + # Call the function to add QKV as outputs + logits_keys_values = add_kv_as_outputs(gm, start_idx_input, end_idx_input) + + gm = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + # gm = insert_torch_cond_before_sdpa(gm, kv_inputs) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm + + diff --git a/examples/dynamo/lower_sdpa.py b/examples/dynamo/lower_sdpa.py new file mode 100644 index 0000000000..cd2b253265 --- /dev/null +++ b/examples/dynamo/lower_sdpa.py @@ -0,0 +1,67 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from cache_utils import add_graph_input + +from sdpa_converter import * +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + +@_aten_lowering_pass +def replace_variants_of_sdpa( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> 
torch.fx.GraphModule: + """Replace scaled_dot_product_attention with an equivalent + implementation which can be accurately converted to TRT + """ + # If sdpa replacement is found, add is_causal_input only once in the graph + is_causal_input = None + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + + # is_causal_input is None if this is the first sdpa node in the graph, otherwise it is reused across all sdpa nodes + if is_causal_input is None: + # Add a new input to the graph for is_causal + is_causal_input = add_graph_input(gm, "is_causal", torch.tensor(True)) + + # Create a new node with torch.nn.functional.scaled_dot_product_attention + # The input args is (query, key, value, is_causal). kwargs has scale + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.nn.functional.scaled_dot_product_attention, + args=node.args[:3] + (is_causal_input,), + kwargs={"scale": node.kwargs.get("scale", None)} + ) + + # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. + new_node.meta = copy.copy(node.meta) + # Check if there's a getitem node following this attention node + for user in list(node.users): + if user.op == "call_function" and user.target == operator.getitem: + # If the getitem is extracting the first element (the output tensor) + if user.args[1] == 0: + # Replace all uses of the getitem with the new attention node + user.replace_all_uses_with(new_node) + # Replace all uses of the original node with the new node + node.replace_all_uses_with(new_node) + + gm.graph.erase_node(node) + + # Clean up the graph + clean_up_graph_after_modifications(gm) + logger.info("Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention") + return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py new file mode 100644 index 0000000000..2f52a08c6a --- /dev/null +++ b/examples/dynamo/sdpa_converter.py @@ -0,0 +1,181 @@ +import math +from typing import Optional, Union, Tuple, Any, Dict +import torch +import torch_tensorrt +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor +import logging +logger = logging.getLogger(__name__) + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # the lower triangle of the tensor means the rows greater than and equal to the cols + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 + ) + # get the rows + row_tensor = impl.elementwise.trunc_div( + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col + ) + # get the cols + col_tensor = impl.elementwise.fmod( + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col + ) + cond = impl.elementwise.ge( + ctx, target, source_ir, name + 
"_ge", row_tensor, col_tensor + ) + return impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", cond, [row, col] + ) + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.nn.functional.scaled_dot_product_attention, enabled=True, supports_dynamic_shapes=True) +def scaled_dot_product_attention( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> TRTTensor: + # TODO: Handle attn_mask and is_causal arguments in the future + query, key, value, is_causal = args + logger.info("Ignoring attn_mask and is_causal arguments provided by the original graph. " + "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True " + "and for generate phase, is_causal=False since we pass only 1 input token at a time") + + + # TODO: remove this once we have a better way to handle the causal mask + scale = kwargs.get("scale", None) + source_ir = SourceIR.ATEN + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + # If is_causal is True, we need to generate a causal mask + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, -2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) + + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) + + # this is to generate a tensor which has shape (L, S), type is int32 + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 + ) + shape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] + ) + + # since we want our attn_bias to be in float32, so cast it to float32 + shape_tensor = cast_trt_tensor( + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir + ) + + # initialize the attn_bias as the zeros tensor + attn_bias = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 + ) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + inf_tensor = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") + ) + cond = impl.elementwise.eq( + ctx, target, 
source_ir, name + "_cond_true", temp_mask, bool(True) + ) + # mask out the certain part of the attn_bias + attn_bias = impl.condition.select( + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + + # Create a if condition to check if is_causal is True + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + attn_weights = output_layer.get_output(0) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", attn_weights, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out \ No newline at end of file diff --git a/examples/dynamo/static_cache.py b/examples/dynamo/static_cache.py new file mode 100644 index 0000000000..430a2000d7 --- /dev/null +++ b/examples/dynamo/static_cache.py @@ -0,0 +1,262 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters and is_causal as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. 
+ + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + - is_causal input node + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get input_ids metadata from the graph. The first placeholder node in the graph is the input_ids node + # find the first placeholder node, then pull out meta["val"] + placeholder_node = next(node for node in gm.graph.nodes if node.op == "placeholder") + input_ids_meta = placeholder_node.meta["val"] + seq_len = input_ids_meta.shape[1] + + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + shape_env = seq_len.node.shape_env + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + pass + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): + # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(cat,), + kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={} + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={} + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={} + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + sdpa_node.args[3:] + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + + # Call the function to add QKV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + 
logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm + + diff --git a/examples/dynamo/test_static_cache.py b/examples/dynamo/test_static_cache.py new file mode 100644 index 0000000000..35c89fdc01 --- /dev/null +++ b/examples/dynamo/test_static_cache.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +from torch.export import export + +class DynamicCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, flag): + def true_fn(q, k, v, k1, v1): + k_new = torch.cat((k, k1), dim=2) + v_new = torch.cat((v, v1), dim=2) + return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) + + def false_fn(q, k, v, k1, v1): + return torch._C._nn.scaled_dot_product_attention(q, k, v) + + out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) + + return 2 * out + +class ModelNoCache(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) + +class StaticCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + + return out, new_key_cache, new_value_cache + + +def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + """ + Eager implementation of SDPA + """ + import math + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + +def print_diff(tensor1, tensor2): + """ + Print the diff between two tensors + """ + print(f"[Diff] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_static_cache(): + """ + Test the static cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Test Prefill + start_idx = 0 + end_idx = 2048 + out_no_cache = model_no_cache(q, k, v) + out_static_cache, new_key_cache, new_value_cache 
= model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) + print(f"[Prefill] Diff between no cache and static cache: {torch.mean(torch.abs(out_no_cache - out_static_cache))}") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False) + print_diff(out_no_cache[:, :, -1:, :], out_static_cache) + q = q_full + k = k_full + v = v_full + + +def main(): + # Create model + # model = ConditionalModel2() + # model.eval() # Set to evaluation mode + with torch.inference_mode(): + test_static_cache() + # # Create example inputs + # q = torch.randn(1, 32, 2048, 64).cuda() + # k = torch.randn(1, 32, 2048, 64).cuda() + # v = torch.randn(1, 32, 2048, 64).cuda() + # k1 = torch.zeros(1, 32, 2176, 64).cuda() + # v1 = torch.zeros(1, 32, 2176, 64).cuda() + # # example_flag = torch.tensor(True) + # start_idx = 0 + # end_idx = 2048 + # out_pyt = model(q, k, v, k1, v1, start_idx, end_idx) + # out_pyt2 = model(q, k, v, k1, v1, 17, 18) + + # exported_program = export( + # model, + # args=(q, k, v, k1, v1, start_idx, end_idx), + # dynamic_shapes=(None, None, None, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC), + # strict=False + # ) + # import torch_tensorrt + # with torch_tensorrt.logging.debug(): + # trt_model = torch_tensorrt.dynamo.compile( + # exported_program, + # inputs=[q, k, v, k1, v1, start_idx, end_idx], + # enabled_precisions={torch.float32}, + # # truncate_double=True, + # disable_tf32=True, + # use_python_runtime=True, + # debug=True, + # min_block_size=1, + # ) + + # gm = exported_program.module() + # breakpoint() + # out_ep = gm(q, k, v, k1, v1, start_idx, end_idx) + # out_ep2 = gm(q, k, v, k1, v1, 2048, 2049) + # out_trt = trt_model(q, k, v, k1, v1, start_idx, end_idx) + # out_trt2 = trt_model(q, k, v, k1, v1, 2048, 2049) + # # breakpoint() + # # Print the graph + # print("\nExported Graph:") + # print(exported_program.graph_module.graph) + # # breakpoint() + # # print("done") + + +if __name__ == "__main__": + main() \ No newline at end of file From 6cbb1bd520ba9e7e22ff738bc90d7f94cd89792d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 20 May 2025 23:26:28 +0000 Subject: [PATCH 09/30] chore: updates --- examples/dynamo/lower_sdpa.py | 19 +- examples/dynamo/sdpa_converter.py | 14 +- examples/dynamo/static_cache.py | 23 +-- examples/dynamo/test_static_cache.py | 268 +++++++++++++++++++++------ examples/dynamo/utils.py | 4 +- 5 files changed, 253 insertions(+), 75 deletions(-) diff --git a/examples/dynamo/lower_sdpa.py b/examples/dynamo/lower_sdpa.py index cd2b253265..c71b168e6f 100644 --- a/examples/dynamo/lower_sdpa.py +++ b/examples/dynamo/lower_sdpa.py @@ -32,18 +32,29 @@ def replace_variants_of_sdpa( is_causal_input = None for node in gm.graph.nodes: if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - # is_causal_input is None if this is the first sdpa node in the graph, otherwise it is reused across all sdpa nodes if is_causal_input is None: # Add a 
new input to the graph for is_causal - is_causal_input = add_graph_input(gm, "is_causal", torch.tensor(True)) + is_causal_input = add_graph_input(gm, "is_causal", True) + is_causal_input.meta["val"] = torch.tensor(True) + + if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if len(node.args) == 7: + query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal = node.args + elif len(node.args) == 5: + query, key, value, attn_mask, is_causal = node.args + dropout_p = 0.0 + elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default: + query, key, value, dropout_p, is_causal, return_debug_mask = node.args + + modified_input_args = (query, key, value, None, dropout_p, is_causal_input) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, - args=node.args[:3] + (is_causal_input,), + args=modified_input_args, kwargs={"scale": node.kwargs.get("scale", None)} ) @@ -56,6 +67,7 @@ def replace_variants_of_sdpa( if user.args[1] == 0: # Replace all uses of the getitem with the new attention node user.replace_all_uses_with(new_node) + new_node.meta['val'] = new_node.meta['val'][0] # Replace all uses of the original node with the new node node.replace_all_uses_with(new_node) @@ -63,5 +75,6 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) + logger.info("Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention") return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 2f52a08c6a..814b9fe26b 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -56,7 +56,7 @@ def scaled_dot_product_attention( name: str, ) -> TRTTensor: # TODO: Handle attn_mask and is_causal arguments in the future - query, key, value, is_causal = args + query, key, value, attn_mask, dropout_p, is_causal = args logger.info("Ignoring attn_mask and is_causal arguments provided by the original graph. " "This converter expects is_causal to be an input to the graph. 
For prefill phase, is_causal=True " "and for generate phase, is_causal=False since we pass only 1 input token at a time") @@ -160,14 +160,14 @@ def scaled_dot_product_attention( ) # Create a if condition to check if is_causal is True - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - attn_weights = output_layer.get_output(0) + # if_layer = ctx.net.add_if_conditional() + # condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + # if_layer.set_condition(condition) + # output_layer = if_layer.add_output(true_branch, false_branch) + # attn_weights = output_layer.get_output(0) softmax = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", attn_weights, -1, False + ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) out = impl.matmul.matrix_multiply( ctx, diff --git a/examples/dynamo/static_cache.py b/examples/dynamo/static_cache.py index 430a2000d7..59b84eee86 100644 --- a/examples/dynamo/static_cache.py +++ b/examples/dynamo/static_cache.py @@ -53,7 +53,7 @@ def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Ten def add_kv_cache_inputs(gm, fixed_kv: bool = True): """ - Add key-value tensors, index parameters and is_causal as inputs to the graph. + Add key-value tensors, index parameters as inputs to the graph. Args: gm: The GraphModule to modify @@ -64,7 +64,6 @@ def add_kv_cache_inputs(gm, fixed_kv: bool = True): - List of (k_input, v_input) node pairs for each SDPA operation - start_idx input node for slicing operations - end_idx input node for slicing operations - - is_causal input node """ def get_static_tensor(tensor: torch.Tensor): @@ -98,21 +97,23 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - # Get input_ids metadata from the graph. The first placeholder node in the graph is the input_ids node - # find the first placeholder node, then pull out meta["val"] - placeholder_node = next(node for node in gm.graph.nodes if node.op == "placeholder") - input_ids_meta = placeholder_node.meta["val"] - seq_len = input_ids_meta.shape[1] + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache, value_cache, .. 
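+    # NOTE: input_ids is assumed to be the first placeholder in the graph; its symbolic
+    # sequence dimension supplies the max bound that extract_var_range_info reads below,
+    # which in turn bounds the unbacked start_idx/end_idx SymInts created for cache slicing.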
+ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + input_ids_meta = input_nodes[0].meta["val"] + seq_len = input_ids_meta.shape[2] + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + from torch.fx.experimental.symbolic_shapes import ShapeEnv + shape_env = ShapeEnv() # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive - shape_env = seq_len.node.shape_env start_idx_unbacked_symint = shape_env.create_unbacked_symint() torch._check(start_idx_unbacked_symint >= 0) - torch._check(start_idx_unbacked_symint <= seq_len) + torch._check(start_idx_unbacked_symint <= max_seq_len) end_idx_unbacked_symint = shape_env.create_unbacked_symint() torch._check(end_idx_unbacked_symint >= 0) - torch._check(end_idx_unbacked_symint <= seq_len) + torch._check(end_idx_unbacked_symint <= max_seq_len) # Set the symbolic ints as the metadata for start_idx and end_idx inputs start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint @@ -123,7 +124,6 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten """ Insert slicing operations before each scaled_dot_product_attention operation. """ - pass # Find all nodes with scaled_dot_product_attention sdpa_nodes = [] for node in gm.graph.nodes: @@ -251,6 +251,7 @@ def insert_kv_cache( logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) gm = clean_up_graph_after_modifications(gm) + new_output_tensors = create_random_output_tensors(logits_keys_values) new_out_spec = pytree.tree_flatten(new_output_tensors)[1] diff --git a/examples/dynamo/test_static_cache.py b/examples/dynamo/test_static_cache.py index 35c89fdc01..645bfcccba 100644 --- a/examples/dynamo/test_static_cache.py +++ b/examples/dynamo/test_static_cache.py @@ -1,6 +1,21 @@ import torch import torch.nn as nn from torch.export import export +import torch_tensorrt +from contextlib import nullcontext +import argparse +from lower_sdpa import * +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) + +ATOL = 1e-8 +RTOL = 1e-5 class DynamicCacheModel(nn.Module): def __init__(self): @@ -69,13 +84,13 @@ def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value -def print_diff(tensor1, tensor2): +def print_diff(tensor1, tensor2, prefix=""): """ Print the diff between two tensors """ - print(f"[Diff] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") -def test_static_cache(): +def test_static_cache_model(args): """ Test the static cache model """ @@ -93,7 +108,8 @@ def test_static_cache(): end_idx = 2048 out_no_cache = model_no_cache(q, k, v) out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) - print(f"[Prefill] Diff between no cache and static cache: {torch.mean(torch.abs(out_no_cache - out_static_cache))}") + # print_diff(out_no_cache, out_static_cache, "Prefill") + torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) # Test Generate for start_idx in range(2048, 2176): 
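For reference, the index bookkeeping these tests exercise follows one pattern (a minimal sketch using the StaticCacheModel defined earlier in this file, with the same shapes as the tests above): prefill writes cache slots [0, prompt_len) in a single call with is_causal=True, and each decode step writes one token into slot [t, t+1) with is_causal=False.

    model = StaticCacheModel().eval().cuda()
    q = torch.randn(1, 32, 2048, 64).cuda()
    k = torch.randn(1, 32, 2048, 64).cuda()
    v = torch.randn(1, 32, 2048, 64).cuda()
    key_cache = torch.zeros(1, 32, 2176, 64).cuda()   # 2176 = max sequence length
    value_cache = torch.zeros(1, 32, 2176, 64).cuda()

    # Prefill: fill cache positions [0, 2048) in one shot
    out, key_cache, value_cache = model(q, k, v, key_cache, value_cache, 0, 2048, is_causal=True)

    # Decode: each step appends a single token at position t, i.e. slice [t, t+1)
    for t in range(2048, 2176):
        q_t = torch.randn(1, 32, 1, 64).cuda()
        k_t = torch.randn(1, 32, 1, 64).cuda()
        v_t = torch.randn(1, 32, 1, 64).cuda()
        out, key_cache, value_cache = model(q_t, k_t, v_t, key_cache, value_cache, t, t + 1, is_causal=False)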
@@ -109,61 +125,209 @@ def test_static_cache(): out_no_cache = model_no_cache(q_full, k_full, v_full) out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False) - print_diff(out_no_cache[:, :, -1:, :], out_static_cache) + # print_diff(out_no_cache[:, :, -1:, :], out_static_cache, f"Generate {start_idx}") + assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL) q = q_full k = k_full v = v_full + print("============== test_static_cache passed ==============") + +def test_static_cache_lowering(args): + """ + Test static cache lowering pass applied to the model with no cache and run the graph module + and compare the output with the model with no cache + """ + import static_cache + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), + strict=False + ) + # Post lower the model + settings = torch_tensorrt.dynamo.conversion.CompilationSettings( + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + out_no_cache = model_no_cache(q, k, v) + out_pyt_cache, key_cache, value_cache = gm(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx) + assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + is_causal = False + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, is_causal, key_cache, value_cache, start_idx, end_idx) + assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL) + q = q_full + k = k_full + v = v_full + + # Test Prefill with torch_tensorrt + q = torch.randn(1, 32, 2, 64).cuda() + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), + strict=False + ) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 
2048, 64).cuda() + out_no_cache = model_no_cache(q, k, v) + out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx) + print_diff(out_no_cache, out_trt, "Prefill TRT") + breakpoint() + # print_diff(trt_key_cache[:, :, :end_idx, :], k, "Prefill TRT key_cache") + # print_diff(trt_value_cache[:, :, :end_idx, :], v, "Prefill TRT value_cache") + assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL) + breakpoint() + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + + is_causal = False + out_static_cache, pyt_key_cache, pyt_value_cache = model_static_cache(q_curr, k_curr, v_curr, pyt_key_cache, pyt_value_cache, start_idx, end_idx, is_causal) + out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) + assert torch.allclose(out_static_cache, out_trt, atol=ATOL, rtol=RTOL) + + print_diff(out_static_cache, out_trt, f"Generate TRT {start_idx}") + # breakpoint() + assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) + assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) + + print("============== test_static_cache_with_torch_tensorrt passed ==============") +# def test_static_cache_with_torch_tensorrt(args): +# """ +# Test the static cache model with torch_tensorrt +# """ +# model_no_cache = ModelNoCache().eval().cuda() +# model_static_cache = StaticCacheModel().eval().cuda() +# q = torch.randn(1, 32, 2048, 64).cuda() +# k = torch.randn(1, 32, 2048, 64).cuda() +# v = torch.randn(1, 32, 2048, 64).cuda() +# key_cache = torch.zeros(1, 32, 2176, 64).cuda() +# value_cache = torch.zeros(1, 32, 2176, 64).cuda() + +# # Test Prefill +# start_idx = 0 +# end_idx = 2048 +# is_causal = True +# out_no_cache = model_no_cache(q, k, v) +# out_static_cache, pyt_key_cache, pyt_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) +# seq_len = torch.export.Dim("seq_len", min=2, max=2048) +# seq_len_dyn_dim = {2 : seq_len} +# exported_program = export( +# model_static_cache, +# args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), +# dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None), +# strict=False +# ) + +# with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): +# trt_model = torch_tensorrt.dynamo.compile( +# exported_program, +# inputs=[q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal], +# enabled_precisions={torch.float32}, +# disable_tf32=True, +# use_python_runtime=True, +# debug=args.debug, +# min_block_size=1, +# ) +# out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) +# print_diff(out_static_cache, out_trt, "Prefill TRT") + +# assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL) +# assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) +# assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) + +# # Test Generate +# for start_idx in range(2048, 2176): +# end_idx = start_idx + 1 +# q_curr = torch.randn(1, 32, 1, 64).cuda() +# k_curr = torch.randn(1, 32, 1, 64).cuda() +# v_curr = torch.randn(1, 32, 1, 64).cuda() + +# is_causal = False +# out_static_cache, 
pyt_key_cache, pyt_value_cache = model_static_cache(q_curr, k_curr, v_curr, pyt_key_cache, pyt_value_cache, start_idx, end_idx, is_causal) +# out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) +# assert torch.allclose(out_static_cache, out_trt, atol=ATOL, rtol=RTOL) + +# print_diff(out_static_cache, out_trt, f"Generate TRT {start_idx}") +# # breakpoint() +# assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) +# assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) + +# print("============== test_static_cache_with_torch_tensorrt passed ==============") + def main(): - # Create model - # model = ConditionalModel2() - # model.eval() # Set to evaluation mode + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() with torch.inference_mode(): - test_static_cache() - # # Create example inputs - # q = torch.randn(1, 32, 2048, 64).cuda() - # k = torch.randn(1, 32, 2048, 64).cuda() - # v = torch.randn(1, 32, 2048, 64).cuda() - # k1 = torch.zeros(1, 32, 2176, 64).cuda() - # v1 = torch.zeros(1, 32, 2176, 64).cuda() - # # example_flag = torch.tensor(True) - # start_idx = 0 - # end_idx = 2048 - # out_pyt = model(q, k, v, k1, v1, start_idx, end_idx) - # out_pyt2 = model(q, k, v, k1, v1, 17, 18) - - # exported_program = export( - # model, - # args=(q, k, v, k1, v1, start_idx, end_idx), - # dynamic_shapes=(None, None, None, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC), - # strict=False - # ) - # import torch_tensorrt - # with torch_tensorrt.logging.debug(): - # trt_model = torch_tensorrt.dynamo.compile( - # exported_program, - # inputs=[q, k, v, k1, v1, start_idx, end_idx], - # enabled_precisions={torch.float32}, - # # truncate_double=True, - # disable_tf32=True, - # use_python_runtime=True, - # debug=True, - # min_block_size=1, - # ) - - # gm = exported_program.module() - # breakpoint() - # out_ep = gm(q, k, v, k1, v1, start_idx, end_idx) - # out_ep2 = gm(q, k, v, k1, v1, 2048, 2049) - # out_trt = trt_model(q, k, v, k1, v1, start_idx, end_idx) - # out_trt2 = trt_model(q, k, v, k1, v1, 2048, 2049) - # # breakpoint() - # # Print the graph - # print("\nExported Graph:") - # print(exported_program.graph_module.graph) - # # breakpoint() - # # print("done") + # test_static_cache_model(args) + test_static_cache_lowering(args) if __name__ == "__main__": diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 26c7c4b605..664428cbe5 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -86,7 +86,6 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr # TODO: Handle batch in this check if not benchmark and stopping_criteria(input_seq, logits).item(): break - return input_seq def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id): @@ -99,10 +98,10 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL logits_concat = [] num_tokens_generated = 0 - kv_cache = get_zeroed_kv_cache_inputs(model) while end_idx < max_output_seq_length: is_causal = True if input_seq.shape[1] > 1 else False + # breakpoint() input_signature = (input_seq, is_causal, *kv_cache, start_idx, end_idx) 
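+        # The flattened call signature mirrors the inputs added by the KV-cache lowering passes:
+        # (input_ids, is_causal, key_cache_0, value_cache_0, ..., start_idx, end_idx).
+        # is_causal is True only for the multi-token prefill call; decode calls feed a single token.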
logits_keys_values = model(*input_signature) num_tokens_generated += 1 @@ -115,6 +114,7 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id input_seq = next_tokens start_idx = end_idx end_idx = start_idx + 1 + lkv = torch.cat(logits_concat, dim=1) return output_seq From f539b55b4f2238e0f017ad3dd649e7bf7f171a9b Mon Sep 17 00:00:00 2001 From: Chengzhe Xu Date: Tue, 27 May 2025 19:13:53 +0000 Subject: [PATCH 10/30] chore: updates --- examples/dynamo/cache_utils.py | 152 +++++++++++++++ examples/dynamo/dynamic_cache.py | 45 ++--- examples/dynamo/llama3_trt.py | 70 ++++++- examples/dynamo/lower_sdpa.py | 20 +- examples/dynamo/sdpa_converter.py | 128 +++++-------- examples/dynamo/static_cache.py | 10 +- examples/dynamo/static_cache2.py | 274 +++++++++++++++++++++++++++ examples/dynamo/test_sdpa.py | 1 - examples/dynamo/test_static_cache.py | 213 +++++++++++---------- examples/dynamo/utils.py | 3 +- 10 files changed, 676 insertions(+), 240 deletions(-) create mode 100644 examples/dynamo/cache_utils.py create mode 100644 examples/dynamo/static_cache2.py diff --git a/examples/dynamo/cache_utils.py b/examples/dynamo/cache_utils.py new file mode 100644 index 0000000000..7a17ad7e65 --- /dev/null +++ b/examples/dynamo/cache_utils.py @@ -0,0 +1,152 @@ +import torch +from torch.fx import Graph, GraphModule, Node +from typing import Optional, Union, Iterable, List, Tuple +from torch._ops import OpOverloadPacket +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._pytree import _LEAF_SPEC +from torch._export.utils import _detect_fake_mode_from_gm + +def get_kv_nodes(gm): + """ + Get the key and value nodes from the graph. + """ + kv_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + q_node, k_node, v_node = node.args[:3] + kv_nodes.append((k_node, v_node)) + return kv_nodes + +def get_random_tensor_from_node(node: Node) -> torch.Tensor: + """ + Creates a random tensor based on the shape information in a node's metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + node: A torch.fx.Node object with metadata containing tensor information + + Returns: + A random tensor with shape matching the node's metadata, or None if no valid + tensor information is found + """ + if "val" not in node.meta: + raise ValueError(f"No tensor information found in node metadata for node: {node}") + + fake_tensor = node.meta["val"] + shape = [] + + # Iterate through each dimension and handle symbolic dimensions + for dim in fake_tensor.shape: + if isinstance(dim, torch.SymInt): + # Extract the maximum value from the shape environment + max_val = dim.node.hint + shape.append(max_val) + else: + shape.append(dim) + + # Create a random tensor with the determined shape + dtype = fake_tensor.dtype + device = fake_tensor.device + random_tensor = torch.rand(shape, dtype=dtype, device=device) + + return random_tensor + +def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: + """ + Creates random tensors based on the shape information in node metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. 
+ + Args: + nodes: List of torch.fx.Node objects with metadata + + Returns: + List of random tensors with shapes matching the nodes' metadata + """ + random_tensors = [] + + for node in nodes: + if isinstance(node, Node): + node_tensor = get_random_tensor_from_node(node) + elif isinstance(node, tuple): + node_tensor_list = [] + for n in node: + random_tensor = get_random_tensor_from_node(n) + node_tensor_list.append(random_tensor) + node_tensor = tuple(node_tensor_list) + + random_tensors.append(node_tensor) + + return random_tensors + +def add_graph_input( + gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None +) -> Node: + """Add a graph input to the given GraphModule and return the newly created node. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the input to. + name (str): The name of the input. + val (torch.Tensor): An example tensor to use for the input. + dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] + """ + # check that no dynamic shape is provided... + if dynamic_shape: + raise NotImplementedError("Dynamic shape not supported for adding graph inputs") + + # extract graph and input spec + graph: Graph = gm.graph + + in_spec = graph._codegen.pytree_info.in_spec + in_spec_for_args = in_spec.children_specs[0] + orig_args = graph._codegen.pytree_info.orig_args + assert in_spec_for_args.type is tuple + + # insert input node after currently last input node + node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] + with graph.inserting_after(node_last_input): + in_node = graph.placeholder(name) + in_spec_for_args.children_specs.append(_LEAF_SPEC) + orig_args.append(f"arg_{name}") + + # update pytree info recursively with __post_init__ starting at leaves + def call_post_init(spec): + for child_spec in spec.children_specs: + call_post_init(child_spec) + spec.__post_init__() + + call_post_init(in_spec) + + # set fake tensor information if all required information is available + fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) + if fake_mode and val is not None and isinstance(val, torch.Tensor): + if isinstance(val, FakeTensor): + fake_tensor = val + else: + fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) + in_node.meta["val"] = fake_tensor + in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) + + # return new node... 
+ return in_node + +def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: + """Check if the node is a call to one of the ops.""" + if node.op != "call_function": + return False + # check if it's a single op that's provided + if isinstance(ops, OpOverloadPacket): + ops = [ops] + + # check if it's the op itself instead of an overload + if any(node.target == op for op in ops): + return True + + return False + +def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: + input_nodes: List[Node] = graph.find_nodes(op="placeholder") + output_nodes: List[Node] = graph.find_nodes(op="output") + return (input_nodes, output_nodes) \ No newline at end of file diff --git a/examples/dynamo/dynamic_cache.py b/examples/dynamo/dynamic_cache.py index c678bac454..3727c4719f 100644 --- a/examples/dynamo/dynamic_cache.py +++ b/examples/dynamo/dynamic_cache.py @@ -14,7 +14,7 @@ clean_up_graph_after_modifications, ) -from .cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op import tensorrt import torch.utils._pytree as pytree logger = logging.getLogger(__name__) @@ -146,23 +146,7 @@ def get_static_tensor(tensor: torch.Tensor): v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) kv_inputs.append((k_input, v_input)) - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx") - end_idx_input = add_graph_input(gm, "end_idx") - return kv_inputs, start_idx_input, end_idx_input - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Insert slicing operations before each scaled_dot_product_attention operation. 
- """ - pass - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - sdpa_nodes.append(node) - - for idx, sdpa_node in enumerate(sdpa_nodes): + return kv_inputs def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): @@ -181,24 +165,27 @@ def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: sdpa_nodes.append(node) + # Get the is_causal input node + is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None) + # For each SDPA node, insert a torch.cond operation before it for idx, sdpa_node in enumerate(sdpa_nodes): with gm.graph.inserting_before(sdpa_node): - pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) + # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) q_node, k_node, v_node = sdpa_node.args[:3] incoming_key, incoming_value = incoming_keys_values[idx] # Create nodes for concatenating k with incoming_key and v with incoming_value concatenated_k_node = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([k_node, incoming_key], 2), # Concatenate along sequence length dimension + args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension kwargs={} ) concatenated_v_node = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([v_node, incoming_value], 2), # Concatenate along sequence length dimension + args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension kwargs={} ) @@ -206,16 +193,16 @@ def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten cond_k_node = gm.graph.create_node( "call_function", torch.ops.higher_order.cond, - args=(pred_node, concatenated_k_node, k_node), + args=(is_causal_node, concatenated_k_node, k_node), ) cond_v_node = gm.graph.create_node( "call_function", torch.ops.higher_order.cond, - args=(pred_node, concatenated_v_node, v_node), + args=(is_causal_node, concatenated_v_node, v_node), ) - sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] return gm @@ -229,13 +216,13 @@ def insert_dynamic_kv_cache( """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True) - # Call the function to add QKV as outputs - logits_keys_values = add_kv_as_outputs(gm, start_idx_input, end_idx_input) + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm) - gm = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - # gm = insert_torch_cond_before_sdpa(gm, kv_inputs) + # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases + gm = insert_torch_cond_before_sdpa(gm, kv_inputs) gm = clean_up_graph_after_modifications(gm) diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llama3_trt.py index 8ade81ccaa..d1e816906c 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llama3_trt.py @@ -19,7 +19,7 @@ import torch_tensorrt from transformers import 
AutoModelForCausalLM, AutoTokenizer from contextlib import nullcontext -from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache +from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs DEVICE = torch.device("cuda:0") @@ -43,7 +43,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=1 + num_hidden_layers=1 ) .eval() .cuda() @@ -194,9 +194,10 @@ def measure_perf(trt_model, input_signature, backend_name): help="Enable pytorch run (default: False)" ) arg_parser.add_argument( - "--kv_cache", - action="store_true", - help="Enable kv_cache (default: False)" + "--cache", + type=str, + default="static", + help="Type of KV cache to use", ) arg_parser.add_argument( "--cudagraph", @@ -220,9 +221,9 @@ def measure_perf(trt_model, input_signature, backend_name): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) prompt = "What is parallel programming ?" + # prompt = "What is the capital of France ?" model_inputs = tokenizer(prompt, return_tensors="pt") input_ids = model_inputs["input_ids"].to(DEVICE) - # Prepare input prompt # word = "What" # word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence @@ -252,18 +253,67 @@ def measure_perf(trt_model, input_signature, backend_name): ) # TRT + pyt_logits_tok1 = model.cuda()(input_ids) + next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1) + input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + pyt_logits_tok2 = model.cuda()(input_seq) from lower_sdpa import * - if args.kv_cache: - # This import is required to register static/dynamic KV cache transformations as lowering passes - from static_cache import * + if args.cache == "static": + # This import is required to register static KV cache transformations as lowering passes + from static_cache2 import * + trt_model = compile_torchtrt(model, input_ids, args) + kv_cache = get_zeroed_kv_cache_inputs(trt_model) + + # First token generation + pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt") + trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1]) + print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}") + print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}") + print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}") + print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}") + print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}") + next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1) + + # Second token generation + trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1) + pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt") + print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}") + print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}") + print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}") + print(f"Diff between pyt and trt values: 
{torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}") + print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}") + breakpoint() + elif args.cache == "dynamic": + from dynamic_cache import * trt_model = compile_torchtrt(model, input_ids, args) + breakpoint() + kv_cache = get_zeroed_kv_cache_inputs(trt_model) else: # pyt_logits = model.cuda()(input_ids.clone()) trt_model = compile_torchtrt(model, input_ids, args) # trt_logits = trt_model(input_ids.clone(), True) # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") - if args.kv_cache: + if args.cache == "static": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. + # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_kv_cache( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_kv_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + elif args.cache == "dynamic": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) diff --git a/examples/dynamo/lower_sdpa.py b/examples/dynamo/lower_sdpa.py index c71b168e6f..c92d707391 100644 --- a/examples/dynamo/lower_sdpa.py +++ b/examples/dynamo/lower_sdpa.py @@ -28,26 +28,24 @@ def replace_variants_of_sdpa( """Replace scaled_dot_product_attention with an equivalent implementation which can be accurately converted to TRT """ - # If sdpa replacement is found, add is_causal_input only once in the graph - is_causal_input = None + for node in gm.graph.nodes: if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - # is_causal_input is None if this is the first sdpa node in the graph, otherwise it is reused across all sdpa nodes - if is_causal_input is None: - # Add a new input to the graph for is_causal - is_causal_input = add_graph_input(gm, "is_causal", True) - is_causal_input.meta["val"] = torch.tensor(True) - if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: if len(node.args) == 7: query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal = node.args elif len(node.args) == 5: query, key, value, attn_mask, is_causal = node.args dropout_p = 0.0 + else: + raise ValueError(f"Unexpected number of arguments for {node.target} in the graph") elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default: query, key, value, dropout_p, is_causal, return_debug_mask = node.args - - modified_input_args = (query, key, value, None, dropout_p, is_causal_input) + + if attn_mask is not None: + logger.warning(f"We do not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration.") + breakpoint() + modified_input_args = (query, key, value, None, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). 
kwargs has scale @@ -75,6 +73,6 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) - + logger.info("Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention") return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 814b9fe26b..71f1ed9795 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -22,29 +22,17 @@ def tril( target: Union[Target, str], source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + row: TRTTensor, + col: TRTTensor, ) -> TRTTensor: - # the lower triangle of the tensor means the rows greater than and equal to the cols - row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) - col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) - rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 - ) - # get the rows - row_tensor = impl.elementwise.trunc_div( - ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col - ) - # get the cols - col_tensor = impl.elementwise.fmod( - ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col - ) - cond = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_tensor, col_tensor - ) - return impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", cond, [row, col] - ) + row_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1) + row_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]) + + col_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1) + col_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]) + + mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor) + return mask @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.nn.functional.scaled_dot_product_attention, enabled=True, supports_dynamic_shapes=True) @@ -103,71 +91,51 @@ def scaled_dot_product_attention( ) # If is_causal is True, we need to generate a causal mask - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, -2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) - - LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) - - # this is to generate a tensor which has shape (L, S), type is int32 - arange_tensor = impl.arange.arange( - ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 - ) - shape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] - ) + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = 
np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, 2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - # since we want our attn_bias to be in float32, so cast it to float32 - shape_tensor = cast_trt_tensor( - ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir - ) + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - # initialize the attn_bias as the zeros tensor - attn_bias = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 - ) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - inf_tensor = impl.elementwise.mul( - ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") - ) - cond = impl.elementwise.eq( - ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) - ) - # mask out the certain part of the attn_bias - attn_bias = impl.condition.select( - ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + temp_mask_casted = cast_trt_tensor(ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir) + one_minus_temp_mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_one_minus_temp_mask", 1.0, temp_mask_casted + ) + attn_bias = impl.unary.log(ctx, target, source_ir, name + "_log", one_minus_temp_mask) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias ) - - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias - ) + else: + scaled_add_attn_bias = scaled # Create a if condition to check if is_causal is True - # if_layer = ctx.net.add_if_conditional() - # condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled - # if_layer.set_condition(condition) - # output_layer = if_layer.add_output(true_branch, false_branch) - # attn_weights = output_layer.get_output(0) + if isinstance(is_causal, TRTTensor): + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + attn_weights = output_layer.get_output(0) softmax = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False + ctx, target, source_ir, name + "_softmax", attn_weights, -1, False ) out = impl.matmul.matrix_multiply( ctx, diff --git a/examples/dynamo/static_cache.py b/examples/dynamo/static_cache.py index 59b84eee86..ae59205aed 100644 --- a/examples/dynamo/static_cache.py +++ b/examples/dynamo/static_cache.py @@ -97,10 +97,10 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache, value_cache, .. 
+ # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] input_ids_meta = input_nodes[0].meta["val"] - seq_len = input_ids_meta.shape[2] + seq_len = input_ids_meta.shape[1] min_max_opt = extract_var_range_info(seq_len) max_seq_len = min_max_opt["max"] @@ -120,6 +120,8 @@ def get_static_tensor(tensor: torch.Tensor): return kv_inputs, start_idx_input, end_idx_input + + def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): """ Insert slicing operations before each scaled_dot_product_attention operation. @@ -247,13 +249,13 @@ def insert_kv_cache( # incoming keys and values from previous tokens (which were added as inputs) gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - # Call the function to add QKV as outputs + # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) gm = clean_up_graph_after_modifications(gm) new_output_tensors = create_random_output_tensors(logits_keys_values) - + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] gm._out_spec = new_out_spec logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) diff --git a/examples/dynamo/static_cache2.py b/examples/dynamo/static_cache2.py new file mode 100644 index 0000000000..6c6d0eb6b1 --- /dev/null +++ b/examples/dynamo/static_cache2.py @@ -0,0 +1,274 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. 
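+        Note: the nodes collected in kv_cache_for_graph are what actually get appended to the existing outputs here, and the rebuilt output tuple is also returned to the caller.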
+ """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. 
+ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + # Get the third last input which should be the last value cache node and store the max_seq_len + input_ids_meta = input_nodes[-3].meta["val"] + seq_len = input_ids_meta.shape[2] + + if isinstance(seq_len, torch.SymInt): + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + else: + max_seq_len = seq_len + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + +def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input): + """ + Create slicing and concatenation nodes for KV cache update. + + This function creates the necessary slicing and concatenation nodes to update the KV cache + during the generation process. It takes the SDPA node, the current KV cache node, and the + incoming KV cache node as input. + Returns: + for a particular SDPA node, a tuple containing: + - List of new current KV nodes + - List of updated incoming KV cache nodes + + """ + + # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # Concat key_cache[:,:,:start_idx,:] with current key (k) + concat_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_kv_node], 2), + kwargs={} + ) + + # =============================================== # + # Create nodes for key_cache[:,:, end_idx:,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + new_incoming_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([concat_keys_or_values, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + new_incoming_keys_or_values.meta.update(incoming_kv_node.meta) + + return concat_keys_or_values, new_incoming_keys_or_values + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + """ + Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + # For keys + new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input) + # For values + new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input) + + # Store the KV cache nodes for the current SDPA node + kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) + + # Update the SDPA node arguments with current key and value nodes + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + sdpa_node.args[3:] + + kv_cache_for_graph.extend([k_node, v_node]) + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + 
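+    # The updated key/value cache nodes collected in kv_cache_for_graph are exposed as extra graph outputs below, so the refreshed caches can be fed back in as inputs on the next call.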
+ # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm + + diff --git a/examples/dynamo/test_sdpa.py b/examples/dynamo/test_sdpa.py index c8e3811925..43c6e285b2 100644 --- a/examples/dynamo/test_sdpa.py +++ b/examples/dynamo/test_sdpa.py @@ -54,7 +54,6 @@ def forward(self, hidden_states, position_embeddings): disable_tf32=True, debug=True) trt_output = trt_model(hidden_states, position_embeddings, None) - breakpoint() if isinstance(pyt_output, tuple): print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") else: diff --git a/examples/dynamo/test_static_cache.py b/examples/dynamo/test_static_cache.py index 645bfcccba..080e346dd0 100644 --- a/examples/dynamo/test_static_cache.py +++ b/examples/dynamo/test_static_cache.py @@ -14,8 +14,10 @@ pre_export_lowering, ) -ATOL = 1e-8 +ATOL = 1e-5 RTOL = 1e-5 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False class DynamicCacheModel(nn.Module): def __init__(self): @@ -45,10 +47,19 @@ class StaticCacheModel(nn.Module): def __init__(self): super().__init__() + # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + # new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + # new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + # out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + + # return out, new_key_cache, new_value_cache + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): - new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) - new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) - out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ] + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) return out, new_key_cache, new_value_cache @@ -62,6 +73,7 @@ def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + breakpoint() if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() @@ -108,8 +120,7 @@ def test_static_cache_model(args): end_idx = 2048 out_no_cache = model_no_cache(q, k, v) out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, 
end_idx, is_causal=True) - # print_diff(out_no_cache, out_static_cache, "Prefill") - torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) + assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) # Test Generate for start_idx in range(2048, 2176): @@ -125,19 +136,42 @@ def test_static_cache_model(args): out_no_cache = model_no_cache(q_full, k_full, v_full) out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False) - # print_diff(out_no_cache[:, :, -1:, :], out_static_cache, f"Generate {start_idx}") + assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL) q = q_full k = k_full v = v_full print("============== test_static_cache passed ==============") +def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args): + """ + Transform the graph module by adding key and value cache to the graph + """ + gm = exported_program.module() + # Post lower the model + settings = torch_tensorrt.dynamo.conversion.CompilationSettings( + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + return gm + def test_static_cache_lowering(args): """ Test static cache lowering pass applied to the model with no cache and run the graph module and compare the output with the model with no cache """ - import static_cache + import static_cache2 model_no_cache = ModelNoCache().eval().cuda() q = torch.randn(1, 32, 2, 64).cuda() @@ -155,22 +189,9 @@ def test_static_cache_lowering(args): dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), strict=False ) - # Post lower the model - settings = torch_tensorrt.dynamo.conversion.CompilationSettings( - enabled_precisions={torch.float32}, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - min_block_size=1, - ) - exported_program = pre_export_lowering(exported_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) - ) - gm = exported_program.module() - gm = post_lowering(gm, settings) - + gm = transform_gm_with_kv_cache(exported_program, args) + # Test Prefill start_idx = 0 end_idx = 2048 @@ -198,16 +219,56 @@ def test_static_cache_lowering(args): q = q_full k = k_full v = v_full + + print("============== test_static_cache_lowering passed ==============") - # Test Prefill with torch_tensorrt +def test_static_cache_export(args): + """ + Test the static cache model export + """ + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + # Export the model + seq_len = torch.export.Dim("seq_len", min=2, max=2048) + seq_len_dyn_dim = {2 : seq_len} + exported_program = export( + model_static_cache, + args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), + dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None), + 
strict=False + ) + + +def test_static_cache_with_torch_tensorrt(args): + """ + Test the static cache model with torch_tensorrt + """ + import static_cache2 + + model_no_cache = ModelNoCache().eval().cuda() q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) exported_program = export( model_no_cache, args=(q, k, v), dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), strict=False ) - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile( exported_program, @@ -223,96 +284,39 @@ def test_static_cache_lowering(args): end_idx = 2048 is_causal = True q = torch.randn(1, 32, 2048, 64).cuda() + # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) out_no_cache = model_no_cache(q, k, v) out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx) - print_diff(out_no_cache, out_trt, "Prefill TRT") - breakpoint() - # print_diff(trt_key_cache[:, :, :end_idx, :], k, "Prefill TRT key_cache") - # print_diff(trt_value_cache[:, :, :end_idx, :], v, "Prefill TRT value_cache") - assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL) - breakpoint() - + # breakpoint() + assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match" + assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match" + assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match" + # Test Generate for start_idx in range(2048, 2176): end_idx = start_idx + 1 q_curr = torch.randn(1, 32, 1, 64).cuda() k_curr = torch.randn(1, 32, 1, 64).cuda() v_curr = torch.randn(1, 32, 1, 64).cuda() - + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) is_causal = False - out_static_cache, pyt_key_cache, pyt_value_cache = model_static_cache(q_curr, k_curr, v_curr, pyt_key_cache, pyt_value_cache, start_idx, end_idx, is_causal) - out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) - assert torch.allclose(out_static_cache, out_trt, atol=ATOL, rtol=RTOL) - - print_diff(out_static_cache, out_trt, f"Generate TRT {start_idx}") + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, is_causal, trt_key_cache, trt_value_cache, start_idx, end_idx) # breakpoint() - assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) - assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) + # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") + # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") + # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}") + assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), 
f"Generate TRT logits don't match for idx {start_idx}" + assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}" + assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}" + q = q_full + k = k_full + v = v_full print("============== test_static_cache_with_torch_tensorrt passed ==============") - -# def test_static_cache_with_torch_tensorrt(args): -# """ -# Test the static cache model with torch_tensorrt -# """ -# model_no_cache = ModelNoCache().eval().cuda() -# model_static_cache = StaticCacheModel().eval().cuda() -# q = torch.randn(1, 32, 2048, 64).cuda() -# k = torch.randn(1, 32, 2048, 64).cuda() -# v = torch.randn(1, 32, 2048, 64).cuda() -# key_cache = torch.zeros(1, 32, 2176, 64).cuda() -# value_cache = torch.zeros(1, 32, 2176, 64).cuda() - -# # Test Prefill -# start_idx = 0 -# end_idx = 2048 -# is_causal = True -# out_no_cache = model_no_cache(q, k, v) -# out_static_cache, pyt_key_cache, pyt_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) -# seq_len = torch.export.Dim("seq_len", min=2, max=2048) -# seq_len_dyn_dim = {2 : seq_len} -# exported_program = export( -# model_static_cache, -# args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), -# dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None), -# strict=False -# ) - -# with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): -# trt_model = torch_tensorrt.dynamo.compile( -# exported_program, -# inputs=[q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal], -# enabled_precisions={torch.float32}, -# disable_tf32=True, -# use_python_runtime=True, -# debug=args.debug, -# min_block_size=1, -# ) -# out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) -# print_diff(out_static_cache, out_trt, "Prefill TRT") - -# assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL) -# assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) -# assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) - -# # Test Generate -# for start_idx in range(2048, 2176): -# end_idx = start_idx + 1 -# q_curr = torch.randn(1, 32, 1, 64).cuda() -# k_curr = torch.randn(1, 32, 1, 64).cuda() -# v_curr = torch.randn(1, 32, 1, 64).cuda() - -# is_causal = False -# out_static_cache, pyt_key_cache, pyt_value_cache = model_static_cache(q_curr, k_curr, v_curr, pyt_key_cache, pyt_value_cache, start_idx, end_idx, is_causal) -# out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) -# assert torch.allclose(out_static_cache, out_trt, atol=ATOL, rtol=RTOL) - -# print_diff(out_static_cache, out_trt, f"Generate TRT {start_idx}") -# # breakpoint() -# assert torch.allclose(pyt_key_cache, trt_key_cache, atol=ATOL, rtol=RTOL) -# assert torch.allclose(pyt_value_cache, trt_value_cache, atol=ATOL, rtol=RTOL) - -# print("============== test_static_cache_with_torch_tensorrt passed ==============") def main(): @@ -327,7 +331,8 @@ def main(): args = arg_parser.parse_args() with torch.inference_mode(): # test_static_cache_model(args) - test_static_cache_lowering(args) + # test_static_cache_lowering(args) + test_static_cache_with_torch_tensorrt(args) if 
__name__ == "__main__": diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py index 664428cbe5..6880d9126e 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/utils.py @@ -86,6 +86,7 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr # TODO: Handle batch in this check if not benchmark and stopping_criteria(input_seq, logits).item(): break + # breakpoint() return input_seq def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id): @@ -115,7 +116,7 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id start_idx = end_idx end_idx = start_idx + 1 lkv = torch.cat(logits_concat, dim=1) - + # breakpoint() return output_seq def time_generate( From 0dc3a7edc550aad927bcecca6815d5a6381b999b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 28 May 2025 23:42:25 +0000 Subject: [PATCH 11/30] chore: updates --- examples/dynamo/lower_sdpa.py | 2 +- examples/dynamo/static_cache2.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/dynamo/lower_sdpa.py b/examples/dynamo/lower_sdpa.py index c92d707391..9240a1f836 100644 --- a/examples/dynamo/lower_sdpa.py +++ b/examples/dynamo/lower_sdpa.py @@ -44,7 +44,7 @@ def replace_variants_of_sdpa( if attn_mask is not None: logger.warning(f"We do not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration.") - breakpoint() + modified_input_args = (query, key, value, None, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention diff --git a/examples/dynamo/static_cache2.py b/examples/dynamo/static_cache2.py index 6c6d0eb6b1..65070def06 100644 --- a/examples/dynamo/static_cache2.py +++ b/examples/dynamo/static_cache2.py @@ -123,7 +123,11 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint - return kv_inputs, start_idx_input, end_idx_input + # Add is_causal as input + is_causal_input = add_graph_input(gm, "is_causal", True) + is_causal_input.meta["val"] = torch.tensor(True) + + return kv_inputs, start_idx_input, end_idx_input, is_causal_input def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input): """ @@ -212,7 +216,7 @@ def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_nod return concat_keys_or_values, new_incoming_keys_or_values -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): """ Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) @@ -239,7 +243,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) # Update the SDPA node arguments with current key and value nodes - sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + sdpa_node.args[3:] + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:] kv_cache_for_graph.extend([k_node, 
v_node]) return gm, kv_cache_for_graph @@ -252,11 +256,11 @@ def insert_kv_cache( """Insert KV cache ops in the graph""" """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True) # Build and update the KV cache using computed KV inputs for current token and # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input) # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) From 600e363db9c2ebe605292347f0bba4ab10fca0cf Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 31 May 2025 02:00:05 +0000 Subject: [PATCH 12/30] feat: Refactor LLM runner and implemented support for Qwen family --- examples/dynamo/{ => llm}/cache_utils.py | 0 examples/dynamo/{ => llm}/dynamic_cache.py | 0 .../dynamo/{llama3_trt.py => llm/run_llm.py} | 196 ++++------ .../static_cache_v1.py} | 17 +- .../static_cache_v2.py} | 9 +- examples/dynamo/llm/test_llama_components.py | 336 ++++++++++++++++++ .../dynamo/llm/test_qwen2.5_components.py | 174 +++++++++ .../dynamo/{ => llm}/test_static_cache.py | 16 +- examples/dynamo/{ => llm}/utils.py | 24 +- examples/dynamo/register_sdpa.py | 17 +- examples/dynamo/sdpa_converter.py | 71 ++-- examples/dynamo/test_sdpa.py | 109 ------ .../dynamo/lowering/_decompositions.py | 24 +- 13 files changed, 681 insertions(+), 312 deletions(-) rename examples/dynamo/{ => llm}/cache_utils.py (100%) rename examples/dynamo/{ => llm}/dynamic_cache.py (100%) rename examples/dynamo/{llama3_trt.py => llm/run_llm.py} (53%) rename examples/dynamo/{static_cache.py => llm/static_cache_v1.py} (94%) rename examples/dynamo/{static_cache2.py => llm/static_cache_v2.py} (96%) create mode 100644 examples/dynamo/llm/test_llama_components.py create mode 100644 examples/dynamo/llm/test_qwen2.5_components.py rename examples/dynamo/{ => llm}/test_static_cache.py (96%) rename examples/dynamo/{ => llm}/utils.py (85%) delete mode 100644 examples/dynamo/test_sdpa.py diff --git a/examples/dynamo/cache_utils.py b/examples/dynamo/llm/cache_utils.py similarity index 100% rename from examples/dynamo/cache_utils.py rename to examples/dynamo/llm/cache_utils.py diff --git a/examples/dynamo/dynamic_cache.py b/examples/dynamo/llm/dynamic_cache.py similarity index 100% rename from examples/dynamo/dynamic_cache.py rename to examples/dynamo/llm/dynamic_cache.py diff --git a/examples/dynamo/llama3_trt.py b/examples/dynamo/llm/run_llm.py similarity index 53% rename from examples/dynamo/llama3_trt.py rename to examples/dynamo/llm/run_llm.py index d1e816906c..942e5a621b 100644 --- a/examples/dynamo/llama3_trt.py +++ b/examples/dynamo/llm/run_llm.py @@ -19,63 +19,29 @@ import torch_tensorrt from transformers import AutoModelForCausalLM, AutoTokenizer from contextlib import nullcontext -from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs +from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache +import sys +import os +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * DEVICE = torch.device("cuda:0") def get_model(args): with torch.no_grad(): - if args.model == "meta-llama/Llama-2-7b-chat-hf": - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1 - ) - .eval() - .cuda() - ) - elif args.model == "meta-llama/Llama-3.2-1B-Instruct": - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1 - ) - .eval() - .cuda() - ) - - elif args.model == "meta-llama/Llama-3.2-3B-Instruct": - model = ( + # Supported list of models: + # - meta-llama/Llama-3.2-1B-Instruct + # - meta-llama/Llama-3.2-3B-Instruct + # - meta-llama/Llama-3.1-8B-Instruct + # - Qwen/Qwen2.5-1.5B-Instruct + model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=2 - ) - .eval() - .cuda() - ) - elif args.model == "meta-llama/Llama-3.1-8B-Instruct": - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", # num_hidden_layers=1 - ) - .eval() - .cuda() - ) - elif args.model == "google/gemma-3-1b-it": - model = ( - AutoModelForCausalLM.from_pretrained( - "google/gemma-3-1b-it", - use_cache=False, - attn_implementation="sdpa" + # num_hidden_layers=1 ) .eval() .cuda() @@ -91,9 +57,9 @@ def get_model(args): def compile_torchtrt(model, input_ids, args): - max_seq_len = input_ids.shape[1] + args.max_tokens + max_seq_len = input_ids.shape[1] + args.num_tokens ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - + # Set precision specific flags use_fp32_acc = False use_explicit_typing = False @@ -119,6 +85,7 @@ def compile_torchtrt(model, input_ids, args): disable_tf32=True, use_python_runtime=True, debug=args.debug, + offload_module_to_cpu=True, min_block_size=args.min_block_size, ) @@ -170,15 +137,15 @@ def measure_perf(trt_model, input_signature, backend_name): "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model" ) arg_parser.add_argument( - "--tokenizer_path", + "--tokenizer", type=str, - default="meta-llama/Llama-3.2-1B-Instruct", + default="", help="Name of LLM model tokenizer", ) arg_parser.add_argument( "--prompt", type=str, default="What is parallel programming ?", help="Prompt" ) - arg_parser.add_argument("--precision", type=str, default="FP16", help="Prompt") + arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32") arg_parser.add_argument( "--iterations", type=int, default=5, help="no. of iterations to run" ) @@ -186,7 +153,13 @@ def measure_perf(trt_model, input_signature, backend_name): "--min_block_size", type=int, default=1, help="no. of iterations to run" ) arg_parser.add_argument( - "--max_tokens", type=int, default=128, help="no. of max tokens to be generated" + "--num_tokens", type=int, default=128, help="no. 
of output tokens to be generated" + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", type=int, default=2048, help="Input sequence length used for benchmarking" ) arg_parser.add_argument( "--enable_pytorch_run", @@ -196,8 +169,8 @@ def measure_perf(trt_model, input_signature, backend_name): arg_parser.add_argument( "--cache", type=str, - default="static", - help="Type of KV cache to use", + default="", + help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", ) arg_parser.add_argument( "--cudagraph", @@ -214,22 +187,24 @@ def measure_perf(trt_model, input_signature, backend_name): action="store_true", help="Enable benchmark (default: False)" ) + args = arg_parser.parse_args() with torch.inference_mode(): model = get_model(args) - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) - prompt = "What is parallel programming ?" - # prompt = "What is the capital of France ?" - model_inputs = tokenizer(prompt, return_tensors="pt") - input_ids = model_inputs["input_ids"].to(DEVICE) - # Prepare input prompt - # word = "What" - # word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence - # input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.max_tokens + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens # Pyt pyt_gen_tokens = None pyt_timings = None @@ -238,7 +213,6 @@ def measure_perf(trt_model, input_signature, backend_name): pyt_gen_tokens = generate( model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id ) - if args.benchmark: pyt_timings = time_generate( generate, @@ -249,71 +223,22 @@ def measure_perf(trt_model, input_signature, backend_name): iterations=args.iterations, ) pyt_stats = recordStats( - "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None + "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None ) - # TRT - pyt_logits_tok1 = model.cuda()(input_ids) - next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1) - input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - pyt_logits_tok2 = model.cuda()(input_seq) - from lower_sdpa import * - if args.cache == "static": - # This import is required to register static KV cache transformations as lowering passes - from static_cache2 import * - trt_model = compile_torchtrt(model, input_ids, args) - kv_cache = get_zeroed_kv_cache_inputs(trt_model) - - # First token generation - pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt") - trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1]) - print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}") - print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}") - 
print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}") - print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}") - print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}") - next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1) - - # Second token generation - trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1) - pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt") - print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}") - print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}") - print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}") - print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}") - print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}") - breakpoint() + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 elif args.cache == "dynamic": - from dynamic_cache import * - trt_model = compile_torchtrt(model, input_ids, args) - breakpoint() - kv_cache = get_zeroed_kv_cache_inputs(trt_model) - else: - # pyt_logits = model.cuda()(input_ids.clone()) - trt_model = compile_torchtrt(model, input_ids, args) - # trt_logits = trt_model(input_ids.clone(), True) - # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}") - # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}") - if args.cache == "static": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. - # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) - - trt_gen_tokens = generate_with_kv_cache( - trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - ) + import dynamic_cache - if args.benchmark: - trt_timings = time_generate( - generate_with_kv_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - elif args.cache == "dynamic": + + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) @@ -332,7 +257,6 @@ def measure_perf(trt_model, input_signature, backend_name): tokenizer.eos_token_id, iterations=args.iterations, ) - else: trt_gen_tokens = generate( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, @@ -349,14 +273,20 @@ def measure_perf(trt_model, input_signature, backend_name): if args.benchmark: trt_stats = recordStats( - "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None + "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None ) - if args.enable_pytorch_run: - print_outputs("PyTorch", pyt_gen_tokens, tokenizer) - print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) - if args.benchmark: + if args.enable_pytorch_run: + print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}") + + if args.benchmark: if args.enable_pytorch_run: print("=========PyTorch PERFORMANCE============ \n") print(pyt_stats) diff --git a/examples/dynamo/static_cache.py b/examples/dynamo/llm/static_cache_v1.py similarity index 94% rename from examples/dynamo/static_cache.py rename to examples/dynamo/llm/static_cache_v1.py index ae59205aed..8c278f2fb6 100644 --- a/examples/dynamo/static_cache.py +++ b/examples/dynamo/llm/static_cache_v1.py @@ -118,11 +118,15 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint - return kv_inputs, start_idx_input, end_idx_input + # Add is_causal as input + is_causal_input = add_graph_input(gm, "is_causal", True) + is_causal_input.meta["val"] = torch.tensor(True) + return kv_inputs, start_idx_input, end_idx_input, is_causal_input -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): """ Insert slicing operations before each scaled_dot_product_attention operation. 
""" @@ -133,7 +137,8 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten sdpa_nodes.append(node) kv_cache_for_graph = [] for idx, sdpa_node in enumerate(sdpa_nodes): - q_node, k_node, v_node = sdpa_node.args[:3] + assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args incoming_key, incoming_value = incoming_keys_values[idx] kv_cache_for_sdpa_node = [] new_keys_values = [] @@ -231,7 +236,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten kv_cache_for_graph.extend(kv_cache_for_sdpa_node) - sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + sdpa_node.args[3:] + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, is_causal_input) return gm, kv_cache_for_graph @@ -243,11 +248,11 @@ def insert_kv_cache( """Insert KV cache ops in the graph""" """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True) # Build and update the KV cache using computed KV inputs for current token and # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input) # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) diff --git a/examples/dynamo/static_cache2.py b/examples/dynamo/llm/static_cache_v2.py similarity index 96% rename from examples/dynamo/static_cache2.py rename to examples/dynamo/llm/static_cache_v2.py index 65070def06..e15d6737b1 100644 --- a/examples/dynamo/static_cache2.py +++ b/examples/dynamo/llm/static_cache_v2.py @@ -97,7 +97,7 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. + # Get the max sequence length from the first key_cache node. 
The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] # Get the third last input which should be the last value cache node and store the max_seq_len input_ids_meta = input_nodes[-3].meta["val"] @@ -232,7 +232,8 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten sdpa_nodes.append(node) kv_cache_for_graph = [] for idx, sdpa_node in enumerate(sdpa_nodes): - q_node, k_node, v_node = sdpa_node.args[:3] + assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args incoming_key, incoming_value = incoming_keys_values[idx] # For keys new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input) @@ -243,9 +244,9 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) # Update the SDPA node arguments with current key and value nodes - sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:] + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, is_causal_input) - kv_cache_for_graph.extend([k_node, v_node]) + # kv_cache_for_graph.extend([k_node, v_node]) return gm, kv_cache_for_graph diff --git a/examples/dynamo/llm/test_llama_components.py b/examples/dynamo/llm/test_llama_components.py new file mode 100644 index 0000000000..9a57629d97 --- /dev/null +++ b/examples/dynamo/llm/test_llama_components.py @@ -0,0 +1,336 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * +ATOL = 1e-5 +RTOL = 1e-5 + + +# llama2_model_name = "meta-llama/Llama-2-7b-hf" +llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" +llama_model = AutoModelForCausalLM.from_pretrained( + llama3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +LLAMA_CONFIG = llama_model.config + +def test_llama_attention(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention( + config=self.config, + layer_idx=0 + ) + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) + return attn_output + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + model = llama_model.model.layers[0].self_attn.to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + if isinstance(pyt_output, tuple): + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + breakpoint() + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_llama_attention_with_static_cache(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention( + config=self.config, + layer_idx=0 + ) + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) + return attn_output + + DTYPE = torch.float32 + model = llama_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() 
+ position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + import register_sdpa + import static_cache2 + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_decoder(args): + + DTYPE = torch.float32 + model = llama_model.model.layers[0].to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def test_llama_decoder_with_static_cache(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer( + 
config=self.config, + layer_idx=0) + self.model = model + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + import register_sdpa + import static_cache2 + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + +def test_llama_model_with_static_cache(args): + + DTYPE = torch.float32 + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(input_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids} + ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes) + + import register_sdpa + import static_cache2 + with 
(torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal) + pyt_output = pyt_output.last_hidden_state + print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda() + position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda() + + # Concatenate the current hidden_states with the previous ones + input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1) + position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1) + is_causal = False + kwarg_inputs = {"input_ids":input_ids_full, "position_ids":position_ids_full} + out_no_cache = model(**kwarg_inputs) + + out_trt, key_cache, value_cache = trt_model(input_ids_curr, position_ids_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache.last_hidden_state[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + input_ids = input_ids_full + position_ids = position_ids_full + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision (default: FP16)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + test_llama_attention(args) + # test_llama_decoder(args) + # test_llama_attention_with_static_cache(args) + # test_llama_decoder_with_static_cache(args) + # test_llama_model_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_qwen2.5_components.py b/examples/dynamo/llm/test_qwen2.5_components.py new file mode 100644 index 0000000000..7921c9622d --- /dev/null +++ b/examples/dynamo/llm/test_qwen2.5_components.py @@ -0,0 +1,174 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct" +qwen2_5_model = AutoModelForCausalLM.from_pretrained( + qwen2_5_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +QWEN_CONFIG = qwen2_5_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_qwen_apply_rotary_pos_emb(args): + class QwenApplyRotaryPosEmb(nn.Module): + def __init__(self): + super().__init__() + + def rotate_half(self, x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, q, k, cos, sin, unsqueeze_dim=1): + return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE) + # Shapes for Qwen 2.5 + q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + + pyt_output = model(q, k, cos, sin) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len}) + ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[q, k, cos, sin], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(q, k, cos, sin) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = 
False + else: + enabled_precisions = {torch.float32} + + model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_apply_rotary_pos_emb(args) + test_qwen_attention(args) diff --git a/examples/dynamo/test_static_cache.py b/examples/dynamo/llm/test_static_cache.py similarity index 96% rename from examples/dynamo/test_static_cache.py rename to examples/dynamo/llm/test_static_cache.py index 080e346dd0..538cbc3f34 100644 --- a/examples/dynamo/test_static_cache.py +++ b/examples/dynamo/llm/test_static_cache.py @@ -4,7 +4,7 @@ import torch_tensorrt from contextlib import nullcontext import argparse -from lower_sdpa import * +import register_sdpa from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer from transformers.models.llama.configuration_llama import LlamaConfig from transformers import AutoModelForCausalLM @@ -198,7 +198,7 @@ def test_static_cache_lowering(args): is_causal = True q = torch.randn(1, 32, 2048, 64).cuda() out_no_cache = model_no_cache(q, k, v) - out_pyt_cache, key_cache, value_cache = gm(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx) + out_pyt_cache, key_cache, value_cache = gm(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) # Test Generate @@ -214,7 +214,7 @@ def test_static_cache_lowering(args): v_full = torch.cat((v, v_curr), dim=2) out_no_cache = model_no_cache(q_full, k_full, v_full) - out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, is_causal, key_cache, value_cache, start_idx, end_idx) + out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, key_cache, value_cache, start_idx, end_idx, is_causal) assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL) q = q_full k = k_full @@ -286,8 +286,8 @@ def 
test_static_cache_with_torch_tensorrt(args): q = torch.randn(1, 32, 2048, 64).cuda() # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) out_no_cache = model_no_cache(q, k, v) - out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx) - # breakpoint() + out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) + assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match" assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match" assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match" @@ -304,7 +304,7 @@ def test_static_cache_with_torch_tensorrt(args): v_full = torch.cat((v, v_curr), dim=2) is_causal = False out_no_cache = model_no_cache(q_full, k_full, v_full) - out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, is_causal, trt_key_cache, trt_value_cache, start_idx, end_idx) + out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) # breakpoint() # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") @@ -330,8 +330,8 @@ def main(): ) args = arg_parser.parse_args() with torch.inference_mode(): - # test_static_cache_model(args) - # test_static_cache_lowering(args) + test_static_cache_model(args) + test_static_cache_lowering(args) test_static_cache_with_torch_tensorrt(args) diff --git a/examples/dynamo/utils.py b/examples/dynamo/llm/utils.py similarity index 85% rename from examples/dynamo/utils.py rename to examples/dynamo/llm/utils.py index 6880d9126e..0ebff2caff 100644 --- a/examples/dynamo/utils.py +++ b/examples/dynamo/llm/utils.py @@ -17,11 +17,12 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): with torch.no_grad(): # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) try: print("Trying to export the model using torch.export.export()..") # strict=False only enables aotautograd tracing and excludes dynamo. ep = torch.export.export( - model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False + model, args=(inputs,), kwargs={"position_ids":position_ids}, dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False ) except: print( @@ -30,8 +31,9 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): # This API is used to express the constraint violation guards as asserts in the graph. ep = torch.export._trace._export( model, - (inputs,), - dynamic_shapes=({1: seq_len},), + args=(inputs,), + kwargs={"position_ids":position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False, allow_complex_guards_as_runtime_asserts=True, ) @@ -54,8 +56,8 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): # placeholder nodes are expected to be in the following order: # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] - # The first two inputs are input_ids and is_causal. 
The last two inputs are start_idx and end_idx. In between are the KV cache tensors. - kv_cache_inputs = placeholder_nodes[2:-2] + # The first two inputs are input_ids, position_ids. The last three inputs are start_idx, end_idx and is_causal. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-3] zeroed_kv_cache_inputs = [] for input in kv_cache_inputs: zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) @@ -75,9 +77,11 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr ) isl = input_seq.shape[1] osl = max_output_seq_length - isl + num_tokens_generated = 0 while num_tokens_generated < osl: - outputs = model(input_seq) + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + outputs = model(input_seq, position_ids) logits = outputs.logits next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) @@ -86,7 +90,7 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr # TODO: Handle batch in this check if not benchmark and stopping_criteria(input_seq, logits).item(): break - # breakpoint() + return input_seq def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id): @@ -95,6 +99,7 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id """ start_idx = 0 end_idx = input_seq.shape[1] + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() output_seq = input_seq.clone() # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL logits_concat = [] @@ -102,8 +107,8 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id kv_cache = get_zeroed_kv_cache_inputs(model) while end_idx < max_output_seq_length: is_causal = True if input_seq.shape[1] > 1 else False - # breakpoint() - input_signature = (input_seq, is_causal, *kv_cache, start_idx, end_idx) + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids + input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx, is_causal) logits_keys_values = model(*input_signature) num_tokens_generated += 1 logits = logits_keys_values[0] @@ -116,7 +121,6 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id start_idx = end_idx end_idx = start_idx + 1 lkv = torch.cat(logits_concat, dim=1) - # breakpoint() return output_seq def time_generate( diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py index 7436f31939..288f2fdfe6 100644 --- a/examples/dynamo/register_sdpa.py +++ b/examples/dynamo/register_sdpa.py @@ -19,11 +19,13 @@ # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. 
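A quick way to see the effect of removing these decompositions: after export, the graph should still contain the SDPA call itself rather than its matmul/softmax decomposition, which is what lets the custom lowering pass and converter match it. A hedged sanity check (the tiny model and shapes are made up for illustration):

import torch
import torch.nn.functional as F

class TinyAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(1, 8, 16, 64)
ep = torch.export.export(TinyAttn(), (q, k, v))
# With the SDPA entries popped from TORCH_TRT_DECOMPOSITIONS, Torch-TensorRT's
# lowering should leave a node like this intact for the custom converter.
sdpa_nodes = [
    n for n in ep.graph.nodes
    if n.op == "call_function" and "scaled_dot_product_attention" in str(n.target)
]
print(f"SDPA nodes in the exported graph: {len(sdpa_nodes)}")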
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_efficient_attention.default + torch.ops.aten._scaled_dot_product_efficient_attention.default, None +) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_flash_attention.default, None ) -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, @@ -71,6 +73,10 @@ def replace_variants_of_sdpa( query, key, value, dropout_p, is_causal, return_debug_mask = ( node.args ) + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = ( + node.args + ) elif len(node.args) == 3: query, key, value = node.args dropout_p = 0.0 @@ -85,14 +91,13 @@ def replace_variants_of_sdpa( ) modified_input_args = (query, key, value, None, dropout_p, is_causal) - # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None)}, + kwargs={"scale": node.kwargs.get("scale", None), "use_fp32_acc": settings.use_fp32_acc}, ) # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. @@ -113,7 +118,7 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) - logger.info( + logger.debug( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 903324dff5..4d51932a20 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -62,25 +62,22 @@ def scaled_dot_product_attention( ) -> TRTTensor: # TODO: Handle attn_mask and is_causal arguments in the future query, key, value, attn_mask, dropout_p, is_causal = args - logger.info( - "Ignoring attn_mask and is_causal arguments provided by the original graph. " - "This converter expects is_causal to be an input to the graph. 
For prefill phase, is_causal=True " - "and for generate phase, is_causal=False since we pass only 1 input token at a time" - ) # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) + use_fp32_acc = kwargs.get("use_fp32_acc", False) + query_dtype = query.dtype + # if use_fp32_acc and query_dtype == trt.float16: + # query = cast_trt_tensor( + # ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir + # ) + # key = cast_trt_tensor( + # ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir + # ) + if scale is None: scale = query.shape[-1] if scale < 0: @@ -90,30 +87,45 @@ def scaled_dot_product_attention( else: # static shape sqrt_scaled = math.sqrt(scale) - scaled = impl.elementwise.div( + key = impl.elementwise.div( ctx, target, source_ir, name + "_scale", - mm, + key, sqrt_scaled, ) else: - scaled = impl.elementwise.mul( + key = impl.elementwise.mul( ctx, target, source_ir, name + "_scale", - mm, + key, scale, ) + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + + # if use_fp32_acc: + # mm = cast_trt_tensor( + # ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir + # ) + # If is_causal is True, we need to generate a causal mask if is_causal: L, S = query.shape[-2], key.shape[-2] if L >= 0 and S >= 0: # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") @@ -133,7 +145,7 @@ def scaled_dot_product_attention( ctx, target, source_ir, name + "_logical_not", tril_tensor ) temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir ) one_minus_temp_mask = impl.elementwise.sub( ctx, @@ -148,15 +160,15 @@ def scaled_dot_product_attention( ) scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias ) else: - scaled_add_attn_bias = scaled - + scaled_add_attn_bias = mm + # Create a if condition to check if is_causal is True if isinstance(is_causal, TRTTensor): if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, mm if_layer.set_condition(condition) output_layer = if_layer.add_output(true_branch, false_branch) scaled_add_attn_bias = output_layer.get_output(0) @@ -164,6 +176,13 @@ def scaled_dot_product_attention( softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) + # if use_fp32_acc: + # softmax = cast_trt_tensor( + # ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir + # ) + # value = cast_trt_tensor( + # ctx, value, 
trt.float32, name + "_value_cast_to_fp32", target, source_ir + # ) out = impl.matmul.matrix_multiply( ctx, target, @@ -172,5 +191,9 @@ def scaled_dot_product_attention( softmax, value, ) + # if use_fp32_acc: + # out = cast_trt_tensor( + # ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir + # ) return out diff --git a/examples/dynamo/test_sdpa.py b/examples/dynamo/test_sdpa.py deleted file mode 100644 index 43c6e285b2..0000000000 --- a/examples/dynamo/test_sdpa.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import TestCase -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers import AutoModelForCausalLM -import torch_tensorrt -from contextlib import nullcontext -import argparse - -# llama2_model_name = "meta-llama/Llama-2-7b-hf" -llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" -llama_model = AutoModelForCausalLM.from_pretrained( - llama3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ).eval().cuda() -LLAMA_CONFIG = llama_model.config - -def test_llama_attention(args): - class LlamaAttentionBlock(nn.Module): - def __init__(self): - super().__init__() - self.config = LLAMA_CONFIG - self.attn = LlamaAttention( - config=self.config, - layer_idx=0 - ) - def forward(self, hidden_states, position_embeddings): - attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) - return attn_output - - DTYPE = torch.float32 - # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) - model = llama_model.model.layers[0].self_attn.to(DTYPE) - # llama3 - # hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() - # position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) - hidden_states = torch.load("hidden_states.pt") - position_embeddings = torch.load("position_embeddings.pt") - # breakpoint() - pyt_output = model(hidden_states, position_embeddings, None) - - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - - with torch_tensorrt.logging.debug(): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions={torch.float32}, - disable_tf32=True, - debug=True) - trt_output = trt_model(hidden_states, position_embeddings, None) - if isinstance(pyt_output, tuple): - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - else: - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") - - -def test_llama_decoder(args): - class LlamaDecoder(nn.Module): - def __init__(self): - super().__init__() - self.config = LLAMA_CONFIG - self.decoder_layer = LlamaDecoderLayer( - config=self.config, - layer_idx=0 - ) - def forward(self, hidden_states, position_embeddings): - decoder_output = self.decoder_layer(hidden_states, position_embeddings=position_embeddings) - return decoder_output[0] - - DTYPE = torch.float32 - model = LlamaDecoder().eval().cuda().to(DTYPE) - # llama3 - hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), 
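For reference, the scaling and masking that the sdpa_converter.py changes above implement (the scale folded into the key before the query/key matmul, an additive -inf causal bias, softmax, then the value matmul) correspond to this eager-mode sketch, assuming the default scale of 1/sqrt(head_dim):

import math
import torch

def sdpa_reference(query, key, value, is_causal=True, scale=None):
    # Same ordering as the converter: fold the scale into key, matmul with the
    # transposed key, add -inf above the diagonal if causal, softmax, matmul value.
    if scale is None:
        key = key / math.sqrt(query.shape[-1])
    else:
        key = key * scale
    attn = query @ key.transpose(-2, -1)
    if is_causal:
        L, S = query.shape[-2], key.shape[-2]
        upper = torch.triu(torch.ones(L, S, dtype=torch.bool, device=query.device), diagonal=1)
        attn = attn.masked_fill(upper, float("-inf"))
    return torch.softmax(attn, dim=-1) @ value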
dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings], - enabled_precisions={torch.float32}, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - args = arg_parser.parse_args() - with torch.inference_mode(): - test_llama_attention(args) - # test_llama_decoder(args) \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 6fb61b0036..14c379c3d5 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -443,9 +443,9 @@ def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tenso return aten._reshape_copy.default(x, size) -# @register_torch_trt_decomposition( -# aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS -# ) +@register_torch_trt_decomposition( + aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS +) def scaled_dot_product_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -489,9 +489,9 @@ def scaled_dot_product_attention_decomposition( return attn_weight @ value -# @register_torch_trt_decomposition( -# aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS -# ) +@register_torch_trt_decomposition( + aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS +) def scaled_dot_product_flash_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -518,9 +518,9 @@ def scaled_dot_product_flash_attention_decomposition( return attn, None, None, None, 0, 0, None, None, None -# @register_torch_trt_decomposition( -# aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS -# ) +@register_torch_trt_decomposition( + aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS +) def scaled_dot_product_efficient_attention_decomposition( query: torch.Tensor, key: torch.Tensor, @@ -538,9 +538,9 @@ def scaled_dot_product_efficient_attention_decomposition( return attn, None, None, None -# @register_torch_trt_decomposition( -# aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS -# ) +@register_torch_trt_decomposition( + aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS +) def scaled_dot_product_cudnn_attention_decomposition( query: torch.Tensor, key: torch.Tensor, From 93097256c5b09748dba7d941b1cabacfc033748e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 4 Jun 2025 09:17:06 +0000 Subject: [PATCH 13/30] chore: updates --- examples/dynamo/llama_benchmark.py | 77 ------ examples/dynamo/llm/run_llm.py | 6 +- examples/dynamo/llm/test_gemma.py | 258 ++++++++++++++++++ .../dynamo/llm/test_qwen2.5_components.py | 1 - examples/dynamo/llm/test_qwen3.py | 175 
++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 4 - 6 files changed, 436 insertions(+), 85 deletions(-) delete mode 100644 examples/dynamo/llama_benchmark.py create mode 100644 examples/dynamo/llm/test_gemma.py create mode 100644 examples/dynamo/llm/test_qwen3.py diff --git a/examples/dynamo/llama_benchmark.py b/examples/dynamo/llama_benchmark.py deleted file mode 100644 index d08c477456..0000000000 --- a/examples/dynamo/llama_benchmark.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import timeit - -USE_CACHE = True -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" -MAX_NEW_TOKENS = 128 - - -def main(): - # Initialize model and tokenizer - print("Loading model and tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - torch_dtype=torch.float16, - use_cache=False, - device_map="auto" - ) - model.generation_config.cache_implementation = "static" - model.forward = torch.compile(model.forward) - - # Prepare input prompt - word = "What" - # Tokenize the word - word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence - # Repeat the token 2048 times - input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device - print(f"Input tensor shape: {input_ids.shape}") - - # # Warm-up pass - print("Running warm-up pass...") - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - - # Benchmark loop - print("Running benchmark...") - num_iterations = 10 - total_time = 0 - timings = [] - - for i in range(num_iterations): - start_time = timeit.default_timer() - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - end_time = timeit.default_timer() - generation_time = end_time - start_time - total_time += generation_time - timings.append(generation_time) - - # Decode and print first iteration output - # if i == 0: - # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) - # print("\nFirst generation output:") - # print(output_text) - - # Calculate and print statistics - average_time = total_time / num_iterations - print(f"\nPerformance Statistics:") - print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") - print(f"Average tokens per second: {100/average_time:.2f}") - print("\nIndividual timings (ms):") - for i, t in enumerate(timings): - print(f"Iteration {i+1}: {t*1000:.2f}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index 942e5a621b..b1555f8d1e 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -41,7 +41,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=1 + num_hidden_layers=2 ) .eval() .cuda() @@ -59,7 +59,7 @@ def get_model(args): def compile_torchtrt(model, input_ids, args): max_seq_len = input_ids.shape[1] + args.num_tokens ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) # Set precision specific flags use_fp32_acc = False use_explicit_typing = False @@ -76,7 +76,7 @@ def compile_torchtrt(model, input_ids, args): with 
(torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile( ep, - inputs=[input_ids], + inputs=[input_ids, position_ids], enabled_precisions=enabled_precisions, # truncate_double=True, use_explicit_typing=use_explicit_typing, diff --git a/examples/dynamo/llm/test_gemma.py b/examples/dynamo/llm/test_gemma.py new file mode 100644 index 0000000000..dc665ce61b --- /dev/null +++ b/examples/dynamo/llm/test_gemma.py @@ -0,0 +1,258 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer +from transformers.models.gemma3.configuration_gemma3 import Gemma3Config +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + + +ATOL = 1e-5 +RTOL = 1e-5 + + +gemma3_model_name = "google/gemma-3-1b-it" +gemma3_model = AutoModelForCausalLM.from_pretrained( + gemma3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +GEMMA3_CONFIG = gemma3_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + + +def test_gemma3_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # gemma3 + hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def 
test_gemma3_attention_with_static_cache(args): + + import static_cache_v2 + DTYPE = torch.float32 + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + +def test_gemma3_decoder(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + model = gemma3_model.model.layers[0].to(DTYPE) + # model.self_attn.is_sliding = False + + # gemma3 + hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda() + position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) + position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes) + + with 
(torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings_global, position_embeddings_local], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + # breakpoint() + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +def test_gemma3_decoder_with_static_cache(args): + + class Gemma3DecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = GEMMA3_CONFIG + self.decoder = Gemma3DecoderLayer( + config=self.config, + layer_idx=0) + self.model = model + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE)) + + import static_cache_v2 + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings_global = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) + position_embeddings_local = (torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = 
hidden_states_full + position_embeddings = position_embeddings_full + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_gemma3_attention(args) + # test_gemma3_attention_with_static_cache(args) + test_gemma3_decoder(args) + # test_gemma3_decoder_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_qwen2.5_components.py b/examples/dynamo/llm/test_qwen2.5_components.py index 7921c9622d..37ffbc5dd5 100644 --- a/examples/dynamo/llm/test_qwen2.5_components.py +++ b/examples/dynamo/llm/test_qwen2.5_components.py @@ -6,7 +6,6 @@ import torch.nn as nn from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import TestCase -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer from transformers.models.llama.configuration_llama import LlamaConfig from transformers import AutoModelForCausalLM import torch_tensorrt diff --git a/examples/dynamo/llm/test_qwen3.py b/examples/dynamo/llm/test_qwen3.py new file mode 100644 index 0000000000..e83419b717 --- /dev/null +++ b/examples/dynamo/llm/test_qwen3.py @@ -0,0 +1,175 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen3_model_name = "Qwen/Qwen3-0.6B" +qwen3_model = AutoModelForCausalLM.from_pretrained( + qwen3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +QWEN_CONFIG = qwen3_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen3_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def test_qwen3_decoder(args): + + class QwenDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = QWEN_CONFIG + self.model = model + def forward(self, hidden_states, position_ids, position_embeddings): + return self.model(hidden_states, position_ids=position_ids, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = QwenDecoderLayerBlock(qwen3_model.model.layers[0].to(DTYPE)) + # qwen3 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_ids, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_ids, position_embeddings), 
dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_ids, position_embeddings], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_ids, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +def test_qwen3_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = qwen3_model.model.to(DTYPE) + # qwen3 + input_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).cuda().unsqueeze(0) + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + ep = torch.export.export(model, (input_ids, position_ids), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[input_ids, position_ids], + enabled_precisions={torch.float32}, + use_python_runtime=True, + disable_tf32=True, + debug=args.debug) + # breakpoint() + trt_output = trt_model(input_ids, position_ids) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[2] - trt_output[2]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP32", help="Precision to use in the model. 
Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_attention(args) + # test_qwen3_decoder(args) + test_qwen3_model(args) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index b89be3ac2a..3700057fd7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -691,10 +691,6 @@ def compile( ) gm = exported_program.module() - exported_program.module().to("cpu") - torch.cuda.empty_cache() - import gc - gc.collect() logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module From a50e7acdd6e54f665dbf9848e0a35589969cf206 Mon Sep 17 00:00:00 2001 From: Chengzhe Xu Date: Thu, 5 Jun 2025 03:22:23 +0000 Subject: [PATCH 14/30] chore: updates --- examples/dynamo/llm/dynamic_cache.py | 20 ++--- examples/dynamo/llm/llama_benchmark.py | 78 ++++++++++++++++++++ examples/dynamo/llm/run_llm.py | 23 ++++-- examples/dynamo/llm/test_llama_components.py | 2 +- examples/dynamo/llm/utils.py | 62 ++++++++++++++-- examples/dynamo/sdpa_converter.py | 4 +- 6 files changed, 166 insertions(+), 23 deletions(-) create mode 100644 examples/dynamo/llm/llama_benchmark.py diff --git a/examples/dynamo/llm/dynamic_cache.py b/examples/dynamo/llm/dynamic_cache.py index 3727c4719f..f9d3b597c6 100644 --- a/examples/dynamo/llm/dynamic_cache.py +++ b/examples/dynamo/llm/dynamic_cache.py @@ -146,10 +146,15 @@ def get_static_tensor(tensor: torch.Tensor): v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) kv_inputs.append((k_input, v_input)) - return kv_inputs + # Add is_generate as input + is_generate_input = add_graph_input(gm, "is_generate", True) + is_generate_input.meta["val"] = torch.tensor(True) -def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]): + return kv_inputs, is_generate_input + + +def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], is_generate_input: torch.Tensor): """ Insert a torch.cond operation before each scaled_dot_product_attention operation. 
@@ -164,9 +169,6 @@ def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: sdpa_nodes.append(node) - - # Get the is_causal input node - is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None) # For each SDPA node, insert a torch.cond operation before it for idx, sdpa_node in enumerate(sdpa_nodes): @@ -193,13 +195,13 @@ def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten cond_k_node = gm.graph.create_node( "call_function", torch.ops.higher_order.cond, - args=(is_causal_node, concatenated_k_node, k_node), + args=(is_generate_input, concatenated_k_node, k_node), ) cond_v_node = gm.graph.create_node( "call_function", torch.ops.higher_order.cond, - args=(is_causal_node, concatenated_v_node, v_node), + args=(is_generate_input, concatenated_v_node, v_node), ) sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] @@ -216,13 +218,13 @@ def insert_dynamic_kv_cache( """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + kv_inputs, is_generate_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm) # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases - gm = insert_torch_cond_before_sdpa(gm, kv_inputs) + gm = insert_torch_cond_before_sdpa(gm, kv_inputs, is_generate_input) gm = clean_up_graph_after_modifications(gm) diff --git a/examples/dynamo/llm/llama_benchmark.py b/examples/dynamo/llm/llama_benchmark.py new file mode 100644 index 0000000000..9c90011b91 --- /dev/null +++ b/examples/dynamo/llm/llama_benchmark.py @@ -0,0 +1,78 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import timeit + +USE_CACHE = False +# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_NEW_TOKENS = 128 + + +def main(): + # Initialize model and tokenizer + print("Loading model and tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.float16, + use_cache=False, + device_map="auto" + ) + # model.generation_config.cache_implementation = "static" + # model.forward = torch.compile(model.forward) + + # Prepare input prompt + word = "What" + # Tokenize the word + word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence + # Repeat the token 2048 times + input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + print(f"Input tensor shape: {input_ids.shape}") + + # # Warm-up pass + print("Running warm-up pass...") + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + + # Benchmark loop + print("Running benchmark...") + num_iterations = 10 + total_time = 0 + timings = [] + + for i in range(num_iterations): + start_time = timeit.default_timer() + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + end_time = timeit.default_timer() + generation_time = end_time - start_time + 
total_time += generation_time + timings.append(generation_time) + + # Decode and print first iteration output + # if i == 0: + # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + # print("\nFirst generation output:") + # print(output_text) + + # Calculate and print statistics + average_time = total_time / num_iterations + print(f"\nPerformance Statistics:") + print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") + print(f"Average tokens per second: {100/average_time:.2f}") + print("\nIndividual timings (ms):") + for i, t in enumerate(timings): + print(f"Iteration {i+1}: {t*1000:.2f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index b1555f8d1e..d4b097e319 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -19,7 +19,7 @@ import torch_tensorrt from transformers import AutoModelForCausalLM, AutoTokenizer from contextlib import nullcontext -from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache +from utils import export_llm, generate, recordStats, time_generate, generate_with_static_cache, generate_with_dynamic_cache import sys import os @@ -41,7 +41,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - num_hidden_layers=2 + # num_hidden_layers=2 ) .eval() .cuda() @@ -238,19 +238,32 @@ def measure_perf(trt_model, input_signature, backend_name): trt_model = compile_torchtrt(model, input_ids, args) - if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic": + if args.cache == "static_v1" or args.cache == "static_v2": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) torch_tensorrt.runtime.set_cudagraphs_mode(True) - trt_gen_tokens = generate_with_kv_cache( + trt_gen_tokens = generate_with_static_cache( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, ) if args.benchmark: trt_timings = time_generate( - generate_with_kv_cache, + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + elif args.cache == "dynamic": + trt_gen_tokens = generate_with_dynamic_cache( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate_with_dynamic_cache, trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, diff --git a/examples/dynamo/llm/test_llama_components.py b/examples/dynamo/llm/test_llama_components.py index 9a57629d97..c925b1b7df 100644 --- a/examples/dynamo/llm/test_llama_components.py +++ b/examples/dynamo/llm/test_llama_components.py @@ -103,7 +103,7 @@ def test_llama_attention_with_static_cache(args): class LlamaAttentionBlock(nn.Module): def __init__(self): super().__init__() - self.config = LLAMA_CONFIG + self.config = LLAMA_CONFIG self.attn = LlamaAttention( config=self.config, layer_idx=0 diff --git a/examples/dynamo/llm/utils.py b/examples/dynamo/llm/utils.py index 0ebff2caff..c807f799f7 100644 --- a/examples/dynamo/llm/utils.py +++ b/examples/dynamo/llm/utils.py @@ -40,9 +40,9 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): return ep -def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): """ - Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. + Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. This function identifies placeholder nodes in the graph that represent KV cache tensors, and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. @@ -64,6 +64,30 @@ def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule): return tuple(zeroed_kv_cache_inputs) +def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors. 
+ kv_cache_inputs = placeholder_nodes[2:-1] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) + + return tuple(zeroed_kv_cache_inputs) + def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): """ @@ -93,9 +117,9 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr return input_seq -def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id): +def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id): """ - Greedy decoding of the model with KV cache. + Greedy decoding of the model with static KV cache. """ start_idx = 0 end_idx = input_seq.shape[1] @@ -104,7 +128,7 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL logits_concat = [] num_tokens_generated = 0 - kv_cache = get_zeroed_kv_cache_inputs(model) + kv_cache = get_zeroed_static_cache_inputs(model) while end_idx < max_output_seq_length: is_causal = True if input_seq.shape[1] > 1 else False position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids @@ -120,9 +144,35 @@ def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id input_seq = next_tokens start_idx = end_idx end_idx = start_idx + 1 - lkv = torch.cat(logits_concat, dim=1) return output_seq +def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with dynamic KV cache. + """ + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + num_output_tokens = max_output_seq_length - input_seq.shape[1] + num_tokens_generated = 0 + kv_cache = get_zeroed_dynamic_cache_inputs(model) + last_position_id = position_ids[-1, -1].item() + breakpoint() + while num_tokens_generated < num_output_tokens: + is_generate = False if input_seq.shape[1] > 1 else True + position_ids = torch.tensor([[last_position_id+1]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids + input_signature = (input_seq, position_ids, *kv_cache, is_generate) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + last_position_id += 1 + return output_seq + + def time_generate( generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 ): diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 4d51932a20..6e43580183 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -76,8 +76,8 @@ def scaled_dot_product_attention( # ) # key = cast_trt_tensor( # ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir - # ) - + # ) + if scale is None: scale = query.shape[-1] if scale < 0: From cbf0d4315ca5a1b69ac372bc0330ca01ad6c9f0e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 5 Jun 2025 09:19:30 +0000 Subject: [PATCH 15/30] chore: set use_fp32_acc to False --- examples/dynamo/sdpa_converter.py | 47 ++++++++++++++++--------------- 1 file changed, 24 insertions(+), 23 
deletions(-) diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 6e43580183..236a616adf 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -68,15 +68,16 @@ def scaled_dot_product_attention( source_ir = SourceIR.ATEN # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - use_fp32_acc = kwargs.get("use_fp32_acc", False) + use_fp32_acc = False # kwargs.get("use_fp32_acc", False) query_dtype = query.dtype - # if use_fp32_acc and query_dtype == trt.float16: - # query = cast_trt_tensor( - # ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir - # ) - # key = cast_trt_tensor( - # ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir - # ) + + if use_fp32_acc and query_dtype == trt.float16: + query = cast_trt_tensor( + ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir + ) + key = cast_trt_tensor( + ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir + ) if scale is None: scale = query.shape[-1] @@ -115,10 +116,10 @@ def scaled_dot_product_attention( other_matrix_op=trt.MatrixOperation.TRANSPOSE, ) - # if use_fp32_acc: - # mm = cast_trt_tensor( - # ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir - # ) + if use_fp32_acc: + mm = cast_trt_tensor( + ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir + ) # If is_causal is True, we need to generate a causal mask if is_causal: @@ -176,13 +177,13 @@ def scaled_dot_product_attention( softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) - # if use_fp32_acc: - # softmax = cast_trt_tensor( - # ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir - # ) - # value = cast_trt_tensor( - # ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir - # ) + if use_fp32_acc: + softmax = cast_trt_tensor( + ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir + ) + value = cast_trt_tensor( + ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir + ) out = impl.matmul.matrix_multiply( ctx, target, @@ -191,9 +192,9 @@ def scaled_dot_product_attention( softmax, value, ) - # if use_fp32_acc: - # out = cast_trt_tensor( - # ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir - # ) + if use_fp32_acc: + out = cast_trt_tensor( + ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir + ) return out From 817be62bff3e05ac2a21645eb4c7d28578d4b981 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 6 Jun 2025 22:20:47 +0000 Subject: [PATCH 16/30] chore: updates --- examples/dynamo/llm/llama_benchmark.py | 78 ------- examples/dynamo/llm/run_llm.py | 6 +- examples/dynamo/llm/test_llama_components.py | 202 ++++++++++++++++--- examples/dynamo/llm/test_static_cache.py | 55 ++++- examples/dynamo/llm/utils.py | 4 +- examples/dynamo/register_sdpa.py | 10 +- examples/dynamo/sdpa_converter.py | 22 +- 7 files changed, 240 insertions(+), 137 deletions(-) delete mode 100644 examples/dynamo/llm/llama_benchmark.py diff --git a/examples/dynamo/llm/llama_benchmark.py b/examples/dynamo/llm/llama_benchmark.py deleted file mode 100644 index 9c90011b91..0000000000 --- a/examples/dynamo/llm/llama_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import timeit - -USE_CACHE = False -# MODEL_NAME = 
"meta-llama/Llama-3.2-1B-Instruct" -MODEL_NAME = "Qwen/Qwen3-0.6B" -MAX_NEW_TOKENS = 128 - - -def main(): - # Initialize model and tokenizer - print("Loading model and tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - torch_dtype=torch.float16, - use_cache=False, - device_map="auto" - ) - # model.generation_config.cache_implementation = "static" - # model.forward = torch.compile(model.forward) - - # Prepare input prompt - word = "What" - # Tokenize the word - word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence - # Repeat the token 2048 times - input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device - print(f"Input tensor shape: {input_ids.shape}") - - # # Warm-up pass - print("Running warm-up pass...") - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - - # Benchmark loop - print("Running benchmark...") - num_iterations = 10 - total_time = 0 - timings = [] - - for i in range(num_iterations): - start_time = timeit.default_timer() - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - end_time = timeit.default_timer() - generation_time = end_time - start_time - total_time += generation_time - timings.append(generation_time) - - # Decode and print first iteration output - # if i == 0: - # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) - # print("\nFirst generation output:") - # print(output_text) - - # Calculate and print statistics - average_time = total_time / num_iterations - print(f"\nPerformance Statistics:") - print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") - print(f"Average tokens per second: {100/average_time:.2f}") - print("\nIndividual timings (ms):") - for i, t in enumerate(timings): - print(f"Iteration {i+1}: {t*1000:.2f}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index d4b097e319..c4ad0ab7e9 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -41,7 +41,6 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=2 ) .eval() .cuda() @@ -209,6 +208,7 @@ def measure_perf(trt_model, input_signature, backend_name): pyt_gen_tokens = None pyt_timings = None pyt_stats = None + if args.enable_pytorch_run: pyt_gen_tokens = generate( model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id @@ -235,9 +235,9 @@ def measure_perf(trt_model, input_signature, backend_name): elif args.cache == "dynamic": import dynamic_cache - + # Compile the model with Torch-TensorRT trt_model = compile_torchtrt(model, input_ids, args) - + if args.cache == "static_v1" or args.cache == "static_v2": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
diff --git a/examples/dynamo/llm/test_llama_components.py b/examples/dynamo/llm/test_llama_components.py index c925b1b7df..c0445e1590 100644 --- a/examples/dynamo/llm/test_llama_components.py +++ b/examples/dynamo/llm/test_llama_components.py @@ -33,17 +33,6 @@ LLAMA_CONFIG = llama_model.config def test_llama_attention(args): - class LlamaAttentionBlock(nn.Module): - def __init__(self): - super().__init__() - self.config = LLAMA_CONFIG - self.attn = LlamaAttention( - config=self.config, - layer_idx=0 - ) - def forward(self, hidden_states, position_embeddings): - attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) - return attn_output DTYPE = torch.float32 if args.precision == "FP16": @@ -71,10 +60,17 @@ def forward(self, hidden_states, position_embeddings): position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + from torch.export._trace import _export + # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False) + ep = _export( + model, + args=(hidden_states, position_embeddings, None), + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile(ep, @@ -87,11 +83,11 @@ def forward(self, hidden_states, position_embeddings): trt_output = trt_model(hidden_states, position_embeddings, None) if isinstance(pyt_output, tuple): print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - breakpoint() assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) else: print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + def print_diff(tensor1, tensor2, prefix=""): """ @@ -103,7 +99,7 @@ def test_llama_attention_with_static_cache(args): class LlamaAttentionBlock(nn.Module): def __init__(self): super().__init__() - self.config = LLAMA_CONFIG + self.config = LLAMA_CONFIG self.attn = LlamaAttention( config=self.config, layer_idx=0 @@ -113,6 +109,23 @@ def forward(self, hidden_states, position_embeddings): return attn_output DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} model = llama_model.model.layers[0].self_attn.to(DTYPE) # Inputs @@ -131,15 +144,16 @@ def forward(self, hidden_states, position_embeddings): seq_len = torch.export.Dim("seq_len", min=2, max=2176) dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - import register_sdpa - import static_cache2 + import static_cache_v2 with (torch_tensorrt.logging.debug() if 
args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile(ep, inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], - enabled_precisions={torch.float32}, + enabled_precisions=enabled_precisions, disable_tf32=True, debug=args.debug, - # offload_module_to_cpu=True, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, use_python_runtime=True) # Test Prefill @@ -167,8 +181,38 @@ def forward(self, hidden_states, position_embeddings): def test_llama_decoder(args): + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer( + config=self.config, + layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + DTYPE = torch.float32 - model = llama_model.model.layers[0].to(DTYPE) + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) # llama3 hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) @@ -181,12 +225,14 @@ def test_llama_decoder(args): with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile(ep, inputs=[hidden_states, position_embeddings], - enabled_precisions={torch.float32}, - debug=args.debug) + enabled_precisions=enabled_precisions, + debug=args.debug, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing) trt_output = trt_model(hidden_states, position_embeddings) - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) def test_llama_decoder_with_static_cache(args): @@ -202,6 +248,23 @@ def forward(self, hidden_states, position_embeddings): return self.model(hidden_states, position_embeddings=position_embeddings) DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) # Inputs @@ -220,15 +283,16 @@ def forward(self, hidden_states, position_embeddings): seq_len = torch.export.Dim("seq_len", min=2, max=2176) dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) ep = torch.export.export(model, args=(hidden_states, position_embeddings), 
dynamic_shapes=dynamic_shapes) - import register_sdpa - import static_cache2 + import static_cache_v2 with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): trt_model = torch_tensorrt.dynamo.compile(ep, arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], - enabled_precisions={torch.float32}, + enabled_precisions=enabled_precisions, disable_tf32=True, debug=args.debug, # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, use_python_runtime=True) # Test Prefill @@ -253,9 +317,83 @@ def forward(self, hidden_states, position_embeddings): hidden_states = hidden_states_full position_embeddings = position_embeddings_full + +def test_llama_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"position_ids":position_ids} + from torch.export._trace import _export + ep = _export(model, args=(input_ids,), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes, strict=False, allow_complex_guards_as_runtime_asserts=True) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True) + + trt_output = trt_model(input_ids, position_ids) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + breakpoint() + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + def test_llama_model_with_static_cache(args): DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} model = llama_model.model.to(DTYPE) # Inputs @@ -276,16 +414,17 @@ def test_llama_model_with_static_cache(args): kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids} ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes) - import register_sdpa - import static_cache2 + import static_cache_v2 with (torch_tensorrt.logging.debug() if args.debug else 
nullcontext()): trt_model = torch_tensorrt.dynamo.compile(ep, arg_inputs=[], kwarg_inputs=kwarg_inputs, - enabled_precisions={torch.float32}, + enabled_precisions=enabled_precisions, disable_tf32=True, debug=args.debug, # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, use_python_runtime=True) # Test Prefill @@ -329,8 +468,9 @@ def test_llama_model_with_static_cache(args): ) args = arg_parser.parse_args() with torch.inference_mode(): - test_llama_attention(args) + # test_llama_attention(args) # test_llama_decoder(args) + test_llama_model(args) # test_llama_attention_with_static_cache(args) # test_llama_decoder_with_static_cache(args) # test_llama_model_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_static_cache.py b/examples/dynamo/llm/test_static_cache.py index 538cbc3f34..13cb384419 100644 --- a/examples/dynamo/llm/test_static_cache.py +++ b/examples/dynamo/llm/test_static_cache.py @@ -4,7 +4,6 @@ import torch_tensorrt from contextlib import nullcontext import argparse -import register_sdpa from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer from transformers.models.llama.configuration_llama import LlamaConfig from transformers import AutoModelForCausalLM @@ -13,6 +12,11 @@ post_lowering, pre_export_lowering, ) +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) ATOL = 1e-5 RTOL = 1e-5 @@ -73,7 +77,7 @@ def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - breakpoint() + if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() @@ -102,6 +106,46 @@ def print_diff(tensor1, tensor2, prefix=""): """ print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_no_cache_model_with_torch_tensorrt(args): + """ + Test the no cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + # q = torch.randn(1, 32, 6, 64).cuda() + # k = torch.randn(1, 32, 6, 64).cuda() + # v = torch.randn(1, 32, 6, 64).cuda() + q = torch.load("query.pt") + k = torch.load("key.pt") + v = torch.load("value.pt") + out_no_cache = model_no_cache(q, k, v) + out_eager = eager_sdpa(q, k, v, is_causal=True) + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + # Export the model + exported_program = torch.export.export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : q_seq_len}, {2 : q_seq_len}), + strict=False + ) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + min_block_size=1, + ) + out_trt = trt_model(q, k, v) + + print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager") + print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt") + print_diff(out_eager, out_trt, "out_eager vs out_trt") + breakpoint() + + def test_static_cache_model(args): """ Test the static cache model @@ -330,9 +374,10 @@ def main(): ) args = arg_parser.parse_args() with torch.inference_mode(): - test_static_cache_model(args) - 
test_static_cache_lowering(args) - test_static_cache_with_torch_tensorrt(args) + test_no_cache_model_with_torch_tensorrt(args) + # test_static_cache_model(args) + # test_static_cache_lowering(args) + # test_static_cache_with_torch_tensorrt(args) if __name__ == "__main__": diff --git a/examples/dynamo/llm/utils.py b/examples/dynamo/llm/utils.py index c807f799f7..941856ada2 100644 --- a/examples/dynamo/llm/utils.py +++ b/examples/dynamo/llm/utils.py @@ -105,7 +105,7 @@ def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=Tr num_tokens_generated = 0 while num_tokens_generated < osl: position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() - outputs = model(input_seq, position_ids) + outputs = model(input_seq, position_ids=position_ids) logits = outputs.logits next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) @@ -126,7 +126,6 @@ def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_toke position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() output_seq = input_seq.clone() # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL - logits_concat = [] num_tokens_generated = 0 kv_cache = get_zeroed_static_cache_inputs(model) while end_idx < max_output_seq_length: @@ -136,7 +135,6 @@ def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_toke logits_keys_values = model(*input_signature) num_tokens_generated += 1 logits = logits_keys_values[0] - logits_concat.append(logits) kv_cache = logits_keys_values[1:] next_token_logits = logits[:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py index 288f2fdfe6..906673a806 100644 --- a/examples/dynamo/register_sdpa.py +++ b/examples/dynamo/register_sdpa.py @@ -61,6 +61,7 @@ def replace_variants_of_sdpa( elif len(node.args) == 5: query, key, value, attn_mask, is_causal = node.args dropout_p = 0.0 + else: raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" @@ -85,12 +86,9 @@ def replace_variants_of_sdpa( raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" ) - if attn_mask is not None: - logger.warning( - f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration." - ) - - modified_input_args = (query, key, value, None, dropout_p, is_causal) + + logger.warning(f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations.") + modified_input_args = (query, key, value, None, dropout_p, True) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). 
kwargs has scale with gm.graph.inserting_after(node): diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 236a616adf..db9211253a 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -68,17 +68,9 @@ def scaled_dot_product_attention( source_ir = SourceIR.ATEN # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - use_fp32_acc = False # kwargs.get("use_fp32_acc", False) + use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype - - if use_fp32_acc and query_dtype == trt.float16: - query = cast_trt_tensor( - ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir - ) - key = cast_trt_tensor( - ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir - ) - + if scale is None: scale = query.shape[-1] if scale < 0: @@ -105,7 +97,15 @@ def scaled_dot_product_attention( key, scale, ) - + + if use_fp32_acc and query_dtype == trt.float16: + query = cast_trt_tensor( + ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir + ) + key = cast_trt_tensor( + ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir + ) + mm = impl.matmul.matrix_multiply( ctx, target, From 7a0663587572814108c85798fa628fe4f211460a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 7 Jun 2025 00:00:46 +0000 Subject: [PATCH 17/30] chore: updates --- examples/dynamo/llm/cache_utils.py | 36 +++++++++++++++++++++++++ examples/dynamo/llm/dynamic_cache.py | 37 +------------------------- examples/dynamo/llm/run_llm.py | 6 ++++- examples/dynamo/llm/static_cache_v1.py | 2 +- examples/dynamo/llm/static_cache_v2.py | 2 +- 5 files changed, 44 insertions(+), 39 deletions(-) diff --git a/examples/dynamo/llm/cache_utils.py b/examples/dynamo/llm/cache_utils.py index 7a17ad7e65..714d1b5b72 100644 --- a/examples/dynamo/llm/cache_utils.py +++ b/examples/dynamo/llm/cache_utils.py @@ -6,6 +6,42 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata from torch.utils._pytree import _LEAF_SPEC from torch._export.utils import _detect_fake_mode_from_gm +import torch_tensorrt +import tensorrt +from typing import Any, Dict, Sequence +from torch.fx.node import Target + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. 
+ + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output def get_kv_nodes(gm): """ diff --git a/examples/dynamo/llm/dynamic_cache.py b/examples/dynamo/llm/dynamic_cache.py index f9d3b597c6..e31939fa99 100644 --- a/examples/dynamo/llm/dynamic_cache.py +++ b/examples/dynamo/llm/dynamic_cache.py @@ -1,10 +1,7 @@ import logging -from typing import Dict, List, Tuple, Union, Sequence, Any +from typing import List, Tuple import torch -from torch.fx.node import Target - -import torch_tensorrt from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( _aten_lowering_pass, @@ -15,41 +12,9 @@ ) from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op -import tensorrt import torch.utils._pytree as pytree logger = logging.getLogger(__name__) -@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) -def cond_converter( - ctx: torch_tensorrt.dynamo.conversion.ConversionContext, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: str, -) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: - """ - Converter for torch.ops.higher_order.cond operation to TensorRT. - - This function handles the conversion of PyTorch's conditional operation to TensorRT. - The conditional operation selects between two tensors based on a boolean predicate. 
- - Args: - ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context - target (Target): The target operation to convert - args (Tuple[Argument, ...]): The arguments to the operation - kwargs (Dict[str, Argument]): The keyword arguments to the operation - name (str): The name to give to the TensorRT layer - - Returns: - Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) - """ - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = args[0], args[1], args[2] - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - output = output_layer.get_output(0) - - return output def add_kv_as_outputs(gm): """ diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index c4ad0ab7e9..336b122c53 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -41,6 +41,7 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", + # num_hidden_layers=1 ) .eval() .cuda() @@ -232,13 +233,16 @@ def measure_perf(trt_model, input_signature, backend_name): if args.cache == "static_v2": # This import is required to register static v2 KV cache transformations as lowering passes import static_cache_v2 + elif args.cache == "static_v3": + # This import is required to register static v3 KV cache transformations as lowering passes + import static_cache_v3 elif args.cache == "dynamic": import dynamic_cache # Compile the model with Torch-TensorRT trt_model = compile_torchtrt(model, input_ids, args) - if args.cache == "static_v1" or args.cache == "static_v2": + if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "static_v3": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) diff --git a/examples/dynamo/llm/static_cache_v1.py b/examples/dynamo/llm/static_cache_v1.py index 8c278f2fb6..0739177706 100644 --- a/examples/dynamo/llm/static_cache_v1.py +++ b/examples/dynamo/llm/static_cache_v1.py @@ -242,7 +242,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten @_aten_lowering_pass -def insert_kv_cache( +def insert_static_cache_v1( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Insert KV cache ops in the graph""" diff --git a/examples/dynamo/llm/static_cache_v2.py b/examples/dynamo/llm/static_cache_v2.py index e15d6737b1..e659a3176b 100644 --- a/examples/dynamo/llm/static_cache_v2.py +++ b/examples/dynamo/llm/static_cache_v2.py @@ -251,7 +251,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten @_aten_lowering_pass -def insert_kv_cache( +def insert_static_cache_v2( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Insert KV cache ops in the graph""" From f47d6ffe673899ab11134e26e125674bcb1235e2 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 7 Jun 2025 00:01:21 +0000 Subject: [PATCH 18/30] chore: add static_cache_v3 --- examples/dynamo/llm/llm_pyt_benchmark.py | 78 ++++++ examples/dynamo/llm/static_cache_v3.py | 294 +++++++++++++++++++++++ 2 files changed, 372 insertions(+) create mode 100644 examples/dynamo/llm/llm_pyt_benchmark.py create mode 100644 examples/dynamo/llm/static_cache_v3.py diff --git a/examples/dynamo/llm/llm_pyt_benchmark.py b/examples/dynamo/llm/llm_pyt_benchmark.py new file mode 100644 index 0000000000..9ae60576a5 --- /dev/null +++ b/examples/dynamo/llm/llm_pyt_benchmark.py @@ -0,0 +1,78 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import timeit + +USE_CACHE = True +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_NEW_TOKENS = 128 + + +def main(): + # Initialize model and tokenizer + print("Loading model and tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.float16, + use_cache=USE_CACHE, + device_map="auto" + ) + # model.generation_config.cache_implementation = "static" + # model.forward = torch.compile(model.forward) + + # Prepare input prompt + word = "What" + # Tokenize the word + word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence + # Repeat the token 2048 times + input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + print(f"Input tensor shape: {input_ids.shape}") + + # # Warm-up pass + print("Running warm-up pass...") + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + + # Benchmark loop + print("Running benchmark...") + num_iterations = 10 + total_time = 0 + timings = [] + + for i in range(num_iterations): + start_time = timeit.default_timer() + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + end_time = timeit.default_timer() + generation_time = end_time - start_time + total_time += generation_time + timings.append(generation_time) + + # Decode and print first iteration output + # if i == 0: + # output_text = 
tokenizer.decode(output_ids[0], skip_special_tokens=True) + # print("\nFirst generation output:") + # print(output_text) + + # Calculate and print statistics + average_time = total_time / num_iterations + print(f"\nPerformance Statistics:") + print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") + print(f"Average tokens per second: {100/average_time:.2f}") + print("\nIndividual timings (ms):") + for i, t in enumerate(timings): + print(f"Iteration {i+1}: {t*1000:.2f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/llm/static_cache_v3.py b/examples/dynamo/llm/static_cache_v3.py new file mode 100644 index 0000000000..6d0db10d52 --- /dev/null +++ b/examples/dynamo/llm/static_cache_v3.py @@ -0,0 +1,294 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. 
+ + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + input_ids_meta = input_nodes[0].meta["val"] + seq_len = input_ids_meta.shape[1] + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + # Add is_causal as input + is_causal_input = add_graph_input(gm, "is_causal", True) + is_causal_input.meta["val"] = torch.tensor(True) + + return kv_inputs, start_idx_input, end_idx_input, is_causal_input + + + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_causal_true_node in enumerate(sdpa_nodes): + assert len(sdpa_causal_true_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_causal_true_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_causal_true_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): + # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_causal_true_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(cat,), + kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={} + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={} + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={} + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + + # Add the new KV cache nodes as inputs to the SDPA node + sdpa_causal_true_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, True) + sdpa_causal_true_node.name = sdpa_causal_true_node.name + "_causal_true" + + # Add a new SDPA node with is_causal=False + with gm.graph.inserting_after(sdpa_causal_true_node): + sdpa_causal_false_node = gm.graph.create_node( + op=sdpa_causal_true_node.op, + target=sdpa_causal_true_node.target, + args=sdpa_causal_true_node.args, + kwargs=sdpa_causal_true_node.kwargs, + name=sdpa_causal_true_node.name + "_causal_false", + ) + sdpa_causal_false_node.args = sdpa_causal_false_node.args[:-1] + (False,) + + with gm.graph.inserting_after(sdpa_causal_false_node): + # Add a torch.cond op which selects between two SDPA nodes ( one with is_causal=True and one with is_causal=False) + cond_node_args = (is_causal_input, sdpa_causal_true_node, sdpa_causal_false_node) + cond_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(), + ) + 
sdpa_causal_true_node.replace_all_uses_with(cond_node) + cond_node.args = cond_node_args + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v3( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm + + From 535c6a8341a3258a9c311406a9af50eb3c68c5a6 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 11 Jun 2025 08:40:21 +0000 Subject: [PATCH 19/30] chore: remove conditional branching for causal attention --- examples/dynamo/llm/run_llm.py | 1 - examples/dynamo/llm/test_static_cache.py | 9 +- examples/dynamo/sdpa_converter.py | 100 +++++++++++------------ 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index 336b122c53..6dd855cdd4 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -203,7 +203,6 @@ def measure_perf(trt_model, input_signature, backend_name): input_ids = model_inputs["input_ids"].to(DEVICE) position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens # Pyt pyt_gen_tokens = None diff --git a/examples/dynamo/llm/test_static_cache.py b/examples/dynamo/llm/test_static_cache.py index 13cb384419..52807f5e93 100644 --- a/examples/dynamo/llm/test_static_cache.py +++ b/examples/dynamo/llm/test_static_cache.py @@ -17,6 +17,7 @@ # Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import register_sdpa ATOL = 1e-5 RTOL = 1e-5 @@ -295,7 +296,7 @@ def test_static_cache_with_torch_tensorrt(args): """ Test the static cache model with torch_tensorrt """ - import static_cache2 + import static_cache_v2 model_no_cache = ModelNoCache().eval().cuda() q = torch.randn(1, 32, 2, 64).cuda() @@ -346,7 +347,7 @@ def test_static_cache_with_torch_tensorrt(args): q_full = torch.cat((q, q_curr), dim=2) k_full = torch.cat((k, k_curr), dim=2) v_full = torch.cat((v, v_curr), dim=2) - is_causal = False + is_causal = True out_no_cache = model_no_cache(q_full, k_full, v_full) out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) # breakpoint() @@ -374,10 +375,10 @@ def main(): ) args = arg_parser.parse_args() with torch.inference_mode(): - test_no_cache_model_with_torch_tensorrt(args) + # test_no_cache_model_with_torch_tensorrt(args) # test_static_cache_model(args) # test_static_cache_lowering(args) - # test_static_cache_with_torch_tensorrt(args) + test_static_cache_with_torch_tensorrt(args) if __name__ == "__main__": diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index db9211253a..c60ad915dd 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -66,7 +66,7 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - + is_causal = True # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype @@ -121,58 +121,58 @@ def scaled_dot_product_attention( ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir ) - # If is_causal is True, we need to generate a causal mask - if is_causal: - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, 2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, 
attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, 2 ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor ) - else: - scaled_add_attn_bias = mm - - # Create a if condition to check if is_causal is True - if isinstance(is_causal, TRTTensor): - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, mm - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - scaled_add_attn_bias = output_layer.get_output(0) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False From 7b7ac045f10e240c1aee61107f3ba631e8af0e29 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 11 Jun 2025 08:40:55 +0000 Subject: [PATCH 20/30] chore: remove conditional branching for causal attention --- examples/dynamo/llm/static_cache_v3.py | 294 ------------------------- 1 file changed, 294 deletions(-) delete mode 100644 examples/dynamo/llm/static_cache_v3.py diff --git a/examples/dynamo/llm/static_cache_v3.py b/examples/dynamo/llm/static_cache_v3.py deleted file mode 100644 index 6d0db10d52..0000000000 --- a/examples/dynamo/llm/static_cache_v3.py +++ /dev/null @@ -1,294 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch.fx import Node - -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -import torch.utils._pytree as pytree -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes -logger = logging.getLogger(__name__) - -SDPA_OP = torch._C._nn.scaled_dot_product_attention - -def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Modifies the graph to add query, key, and value tensors as outputs. 
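Stepping back to the sdpa_converter.py change above: the causal bias it assembles out of TensorRT layers (tril -> logical_not -> an (L == S) need_mask gate -> 1 - mask -> log) is equivalent in eager PyTorch to roughly the following sketch. The function name and signature here are illustrative only and are not part of the converter:

    import torch

    def causal_attn_bias(L, S, dtype=torch.float32):
        # tril marks the positions a causal query may attend to.
        disallowed = ~torch.tril(torch.ones(L, S, dtype=torch.bool))
        if L != S:
            # Mirrors the need_mask gate: with a KV cache the decode step has
            # L == 1 != S, so no causal masking is applied.
            disallowed = torch.zeros_like(disallowed)
        # log(1 - 1) = -inf on disallowed positions, log(1 - 0) = 0 elsewhere,
        # giving an additive bias for the attention scores.
        return torch.log(1.0 - disallowed.to(dtype))
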
- - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. - """ - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + tuple(kv_cache_for_graph) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - -def add_kv_cache_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors, index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - - # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. 
- input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] - input_ids_meta = input_nodes[0].meta["val"] - seq_len = input_ids_meta.shape[1] - min_max_opt = extract_var_range_info(seq_len) - max_seq_len = min_max_opt["max"] - - from torch.fx.experimental.symbolic_shapes import ShapeEnv - shape_env = ShapeEnv() - # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive - start_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(start_idx_unbacked_symint >= 0) - torch._check(start_idx_unbacked_symint <= max_seq_len) - - end_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(end_idx_unbacked_symint >= 0) - torch._check(end_idx_unbacked_symint <= max_seq_len) - # Set the symbolic ints as the metadata for start_idx and end_idx inputs - start_idx_input.meta["val"] = start_idx_unbacked_symint - end_idx_input.meta["val"] = end_idx_unbacked_symint - - # Add is_causal as input - is_causal_input = add_graph_input(gm, "is_causal", True) - is_causal_input.meta["val"] = torch.tensor(True) - - return kv_inputs, start_idx_input, end_idx_input, is_causal_input - - - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): - """ - Insert slicing operations before each scaled_dot_product_attention operation. - """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == SDPA_OP: - sdpa_nodes.append(node) - kv_cache_for_graph = [] - for idx, sdpa_causal_true_node in enumerate(sdpa_nodes): - assert len(sdpa_causal_true_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_causal_true_node.args)} arguments" - q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_causal_true_node.args - incoming_key, incoming_value = incoming_keys_values[idx] - kv_cache_for_sdpa_node = [] - new_keys_values = [] - for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): - # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim - with gm.graph.inserting_before(sdpa_causal_true_node): - slice_1 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_2 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_1, 1), - kwargs={} - ) - slice_3 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_2, 2, None, start_idx_input), - kwargs={} - ) - slice_4 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_3, 3), - kwargs={} - ) - # =============================================== # - # Create a slice node for key_cache[:,:, end_idx:,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim - slice_5 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_6 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_5, 1), - kwargs={} - ) - slice_7 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_6, 2, end_idx_input), - kwargs={} - ) - slice_8 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_7, 3), - kwargs={} - ) - # =============================================== # - # Concatenate the sliced tensors to build KV cache - cat = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([slice_4, current_key_or_value_node, slice_8], 2), - kwargs={} - ) - # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph - cat.meta.update(key_or_value.meta) - kv_cache_for_sdpa_node.append(cat) - # =============================================== # - # Get the current key and value by indexing the KV cache - slice_9 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(cat,), - kwargs={} - ) - slice_10 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_9, 1), - kwargs={} - ) - slice_11 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_10, 2, None, end_idx_input), - kwargs={} - ) - slice_12 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_11, 3), - kwargs={} - ) - new_keys_values.append(slice_12) - - kv_cache_for_graph.extend(kv_cache_for_sdpa_node) - - # Add the new KV cache nodes as inputs to the SDPA node - sdpa_causal_true_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, True) - sdpa_causal_true_node.name = sdpa_causal_true_node.name + "_causal_true" - - # Add a new SDPA node with is_causal=False - with gm.graph.inserting_after(sdpa_causal_true_node): - sdpa_causal_false_node = gm.graph.create_node( - op=sdpa_causal_true_node.op, - target=sdpa_causal_true_node.target, - args=sdpa_causal_true_node.args, - kwargs=sdpa_causal_true_node.kwargs, - name=sdpa_causal_true_node.name + "_causal_false", - ) - sdpa_causal_false_node.args = sdpa_causal_false_node.args[:-1] + (False,) - - with gm.graph.inserting_after(sdpa_causal_false_node): - # Add a torch.cond op which selects between two SDPA nodes ( one with is_causal=True and one with is_causal=False) - cond_node_args = (is_causal_input, sdpa_causal_true_node, sdpa_causal_false_node) - cond_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(), - ) - sdpa_causal_true_node.replace_all_uses_with(cond_node) - cond_node.args = cond_node_args - - return gm, kv_cache_for_graph - - -@_aten_lowering_pass -def insert_static_cache_v3( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True) - - # Build and update the KV cache using computed KV inputs for current token and - # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, 
is_causal_input) - - # Call the function to add KV as outputs - logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - - return gm - - From c1f0053fff5fd62365cb0b5d1aedc6f69f741c2d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 11 Jun 2025 17:04:58 +0000 Subject: [PATCH 21/30] chore: remove is_causal input now that causal attention enhancement is landed --- examples/dynamo/llm/run_llm.py | 5 +---- examples/dynamo/llm/static_cache_v1.py | 14 +++++--------- examples/dynamo/llm/static_cache_v2.py | 16 ++++++---------- examples/dynamo/llm/utils.py | 7 +++---- 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py index 6dd855cdd4..d536cd12e4 100644 --- a/examples/dynamo/llm/run_llm.py +++ b/examples/dynamo/llm/run_llm.py @@ -232,16 +232,13 @@ def measure_perf(trt_model, input_signature, backend_name): if args.cache == "static_v2": # This import is required to register static v2 KV cache transformations as lowering passes import static_cache_v2 - elif args.cache == "static_v3": - # This import is required to register static v3 KV cache transformations as lowering passes - import static_cache_v3 elif args.cache == "dynamic": import dynamic_cache # Compile the model with Torch-TensorRT trt_model = compile_torchtrt(model, input_ids, args) - if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "static_v3": + if args.cache == "static_v1" or args.cache == "static_v2": if args.cudagraph: # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) diff --git a/examples/dynamo/llm/static_cache_v1.py b/examples/dynamo/llm/static_cache_v1.py index 0739177706..943718de2e 100644 --- a/examples/dynamo/llm/static_cache_v1.py +++ b/examples/dynamo/llm/static_cache_v1.py @@ -118,15 +118,11 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint - # Add is_causal as input - is_causal_input = add_graph_input(gm, "is_causal", True) - is_causal_input.meta["val"] = torch.tensor(True) + return kv_inputs, start_idx_input, end_idx_input - return kv_inputs, start_idx_input, end_idx_input, is_causal_input - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): """ Insert slicing operations before each scaled_dot_product_attention operation. 
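For orientation, the slice/cat subgraph that this pass (and its static_cache_v2 counterpart) wires in before each SDPA call corresponds roughly to the following eager-mode sketch. This is illustrative only: the function name is made up, key_cache/value_cache stand for the preallocated batch_size x num_heads x max_seq_len x head_dim cache inputs, and start_idx/end_idx bound the slot being written:

    import torch

    def update_static_kv_cache(key_cache, value_cache, k, v, start_idx, end_idx):
        # Splice the freshly computed k/v into the fixed-size cache, keeping the
        # already-written prefix and the unused tail so the shape stays static.
        new_key_cache = torch.cat(
            (key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2
        )
        new_value_cache = torch.cat(
            (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2
        )
        # SDPA only attends over the populated prefix of the cache.
        return (
            new_key_cache,
            new_value_cache,
            new_key_cache[:, :, :end_idx, :],
            new_value_cache[:, :, :end_idx, :],
        )
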
""" @@ -236,7 +232,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten kv_cache_for_graph.extend(kv_cache_for_sdpa_node) - sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, is_causal_input) + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, True) return gm, kv_cache_for_graph @@ -248,11 +244,11 @@ def insert_static_cache_v1( """Insert KV cache ops in the graph""" """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True) + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) # Build and update the KV cache using computed KV inputs for current token and # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) diff --git a/examples/dynamo/llm/static_cache_v2.py b/examples/dynamo/llm/static_cache_v2.py index e659a3176b..e2a40d39f7 100644 --- a/examples/dynamo/llm/static_cache_v2.py +++ b/examples/dynamo/llm/static_cache_v2.py @@ -123,11 +123,7 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint - # Add is_causal as input - is_causal_input = add_graph_input(gm, "is_causal", True) - is_causal_input.meta["val"] = torch.tensor(True) - - return kv_inputs, start_idx_input, end_idx_input, is_causal_input + return kv_inputs, start_idx_input, end_idx_input def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input): """ @@ -216,7 +212,7 @@ def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_nod return concat_keys_or_values, new_incoming_keys_or_values -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node): +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): """ Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) @@ -244,7 +240,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) # Update the SDPA node arguments with current key and value nodes - sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, is_causal_input) + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, True) # kv_cache_for_graph.extend([k_node, v_node]) return gm, kv_cache_for_graph @@ -257,11 +253,11 @@ def insert_static_cache_v2( """Insert KV cache ops in the graph""" """Perform insertion of kv-caches and attention kernel.""" # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input, is_causal_input = 
add_kv_cache_inputs(gm, fixed_kv=True) + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) # Build and update the KV cache using computed KV inputs for current token and # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) # Call the function to add KV as outputs logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) @@ -272,8 +268,8 @@ def insert_static_cache_v2( new_out_spec = pytree.tree_flatten(new_output_tensors)[1] gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm diff --git a/examples/dynamo/llm/utils.py b/examples/dynamo/llm/utils.py index 941856ada2..c43f90acc5 100644 --- a/examples/dynamo/llm/utils.py +++ b/examples/dynamo/llm/utils.py @@ -56,8 +56,8 @@ def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): # placeholder nodes are expected to be in the following order: # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] - # The first two inputs are input_ids, position_ids. The last three inputs are start_idx, end_idx and is_causal. In between are the KV cache tensors. - kv_cache_inputs = placeholder_nodes[2:-3] + # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-2] zeroed_kv_cache_inputs = [] for input in kv_cache_inputs: zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) @@ -129,9 +129,8 @@ def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_toke num_tokens_generated = 0 kv_cache = get_zeroed_static_cache_inputs(model) while end_idx < max_output_seq_length: - is_causal = True if input_seq.shape[1] > 1 else False position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids - input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx, is_causal) + input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx) logits_keys_values = model(*input_signature) num_tokens_generated += 1 logits = logits_keys_values[0] From 8301ee667eb2142811fd66c95cf46adbb4b17c62 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Jun 2025 22:30:44 +0000 Subject: [PATCH 22/30] chore: refactor --- examples/dynamo/llm/cache_utils.py | 188 ------- examples/dynamo/llm/dynamic_cache.py | 203 -------- examples/dynamo/llm/llm_pyt_benchmark.py | 78 --- examples/dynamo/llm/run_llm.py | 308 ------------ examples/dynamo/llm/static_cache_v1.py | 266 ---------- examples/dynamo/llm/static_cache_v2.py | 275 ---------- examples/dynamo/llm/test_gemma.py | 258 ---------- examples/dynamo/llm/test_llama_components.py | 476 ------------------ .../dynamo/llm/test_qwen2.5_components.py | 173 ------- examples/dynamo/llm/test_qwen3.py | 175 ------- examples/dynamo/llm/test_static_cache.py | 385 -------------- examples/dynamo/llm/utils.py | 216 -------- examples/dynamo/register_sdpa.py | 122 ----- examples/dynamo/sdpa_converter.py | 200 -------- examples/dynamo/torch_export_gpt2.py | 98 ---- examples/dynamo/torch_export_llama2.py | 102 ---- 16 files changed, 3523 
deletions(-) delete mode 100644 examples/dynamo/llm/cache_utils.py delete mode 100644 examples/dynamo/llm/dynamic_cache.py delete mode 100644 examples/dynamo/llm/llm_pyt_benchmark.py delete mode 100644 examples/dynamo/llm/run_llm.py delete mode 100644 examples/dynamo/llm/static_cache_v1.py delete mode 100644 examples/dynamo/llm/static_cache_v2.py delete mode 100644 examples/dynamo/llm/test_gemma.py delete mode 100644 examples/dynamo/llm/test_llama_components.py delete mode 100644 examples/dynamo/llm/test_qwen2.5_components.py delete mode 100644 examples/dynamo/llm/test_qwen3.py delete mode 100644 examples/dynamo/llm/test_static_cache.py delete mode 100644 examples/dynamo/llm/utils.py delete mode 100644 examples/dynamo/register_sdpa.py delete mode 100644 examples/dynamo/sdpa_converter.py delete mode 100644 examples/dynamo/torch_export_gpt2.py delete mode 100644 examples/dynamo/torch_export_llama2.py diff --git a/examples/dynamo/llm/cache_utils.py b/examples/dynamo/llm/cache_utils.py deleted file mode 100644 index 714d1b5b72..0000000000 --- a/examples/dynamo/llm/cache_utils.py +++ /dev/null @@ -1,188 +0,0 @@ -import torch -from torch.fx import Graph, GraphModule, Node -from typing import Optional, Union, Iterable, List, Tuple -from torch._ops import OpOverloadPacket -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.fx.passes.shape_prop import _extract_tensor_metadata -from torch.utils._pytree import _LEAF_SPEC -from torch._export.utils import _detect_fake_mode_from_gm -import torch_tensorrt -import tensorrt -from typing import Any, Dict, Sequence -from torch.fx.node import Target - -@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) -def cond_converter( - ctx: torch_tensorrt.dynamo.conversion.ConversionContext, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: str, -) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: - """ - Converter for torch.ops.higher_order.cond operation to TensorRT. - - This function handles the conversion of PyTorch's conditional operation to TensorRT. - The conditional operation selects between two tensors based on a boolean predicate. - - Args: - ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context - target (Target): The target operation to convert - args (Tuple[Argument, ...]): The arguments to the operation - kwargs (Dict[str, Argument]): The keyword arguments to the operation - name (str): The name to give to the TensorRT layer - - Returns: - Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) - """ - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = args[0], args[1], args[2] - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - output = output_layer.get_output(0) - - return output - -def get_kv_nodes(gm): - """ - Get the key and value nodes from the graph. - """ - kv_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - q_node, k_node, v_node = node.args[:3] - kv_nodes.append((k_node, v_node)) - return kv_nodes - -def get_random_tensor_from_node(node: Node) -> torch.Tensor: - """ - Creates a random tensor based on the shape information in a node's metadata. - For symbolic dimensions, extracts the maximum value from the shape environment. 
- - Args: - node: A torch.fx.Node object with metadata containing tensor information - - Returns: - A random tensor with shape matching the node's metadata, or None if no valid - tensor information is found - """ - if "val" not in node.meta: - raise ValueError(f"No tensor information found in node metadata for node: {node}") - - fake_tensor = node.meta["val"] - shape = [] - - # Iterate through each dimension and handle symbolic dimensions - for dim in fake_tensor.shape: - if isinstance(dim, torch.SymInt): - # Extract the maximum value from the shape environment - max_val = dim.node.hint - shape.append(max_val) - else: - shape.append(dim) - - # Create a random tensor with the determined shape - dtype = fake_tensor.dtype - device = fake_tensor.device - random_tensor = torch.rand(shape, dtype=dtype, device=device) - - return random_tensor - -def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: - """ - Creates random tensors based on the shape information in node metadata. - For symbolic dimensions, extracts the maximum value from the shape environment. - - Args: - nodes: List of torch.fx.Node objects with metadata - - Returns: - List of random tensors with shapes matching the nodes' metadata - """ - random_tensors = [] - - for node in nodes: - if isinstance(node, Node): - node_tensor = get_random_tensor_from_node(node) - elif isinstance(node, tuple): - node_tensor_list = [] - for n in node: - random_tensor = get_random_tensor_from_node(n) - node_tensor_list.append(random_tensor) - node_tensor = tuple(node_tensor_list) - - random_tensors.append(node_tensor) - - return random_tensors - -def add_graph_input( - gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None -) -> Node: - """Add a graph input to the given GraphModule and return the newly created node. - - NOTE: function does NOT do any graph canonicalization. This is left to the user! - - Args: - gm (GraphModule): The GraphModule to add the input to. - name (str): The name of the input. - val (torch.Tensor): An example tensor to use for the input. - dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] - """ - # check that no dynamic shape is provided... 
- if dynamic_shape: - raise NotImplementedError("Dynamic shape not supported for adding graph inputs") - - # extract graph and input spec - graph: Graph = gm.graph - - in_spec = graph._codegen.pytree_info.in_spec - in_spec_for_args = in_spec.children_specs[0] - orig_args = graph._codegen.pytree_info.orig_args - assert in_spec_for_args.type is tuple - - # insert input node after currently last input node - node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] - with graph.inserting_after(node_last_input): - in_node = graph.placeholder(name) - in_spec_for_args.children_specs.append(_LEAF_SPEC) - orig_args.append(f"arg_{name}") - - # update pytree info recursively with __post_init__ starting at leaves - def call_post_init(spec): - for child_spec in spec.children_specs: - call_post_init(child_spec) - spec.__post_init__() - - call_post_init(in_spec) - - # set fake tensor information if all required information is available - fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) - if fake_mode and val is not None and isinstance(val, torch.Tensor): - if isinstance(val, FakeTensor): - fake_tensor = val - else: - fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) - in_node.meta["val"] = fake_tensor - in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) - - # return new node... - return in_node - -def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: - """Check if the node is a call to one of the ops.""" - if node.op != "call_function": - return False - # check if it's a single op that's provided - if isinstance(ops, OpOverloadPacket): - ops = [ops] - - # check if it's the op itself instead of an overload - if any(node.target == op for op in ops): - return True - - return False - -def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: - input_nodes: List[Node] = graph.find_nodes(op="placeholder") - output_nodes: List[Node] = graph.find_nodes(op="output") - return (input_nodes, output_nodes) \ No newline at end of file diff --git a/examples/dynamo/llm/dynamic_cache.py b/examples/dynamo/llm/dynamic_cache.py deleted file mode 100644 index e31939fa99..0000000000 --- a/examples/dynamo/llm/dynamic_cache.py +++ /dev/null @@ -1,203 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op -import torch.utils._pytree as pytree -logger = logging.getLogger(__name__) - - -def add_kv_as_outputs(gm): - """ - Modifies the graph to add query, key, and value tensors as outputs. - - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. 
- """ - # list of MHA kernels we would want to detect and replace - mha_ops = { - torch._C._nn.scaled_dot_product_attention, - } - - # Find all SDPA nodes in the graph - mha_nodes = [] - for node in gm.graph.nodes: - if is_op(node, mha_ops): - mha_nodes.append(node) - - # Iterate through each MHA node to extract shape information - for mha_node in mha_nodes: - if "val" in mha_node.meta and len(mha_node.args) >= 3: - # Get the input nodes (query, key, value) - q_node, k_node, v_node = mha_node.args[:3] - - # Add the copy nodes as outputs to the graph - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + ((k_node, v_node),) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs, (k_node, v_node)) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - - - -def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors and index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - - # Add is_generate as input - is_generate_input = add_graph_input(gm, "is_generate", True) - is_generate_input.meta["val"] = torch.tensor(True) - - return kv_inputs, is_generate_input - - -def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], is_generate_input: torch.Tensor): - """ - Insert a torch.cond operation before each scaled_dot_product_attention operation. 
- - Args: - gm: The FX GraphModule to modify - - Returns: - The modified GraphModule - """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: - sdpa_nodes.append(node) - - # For each SDPA node, insert a torch.cond operation before it - for idx, sdpa_node in enumerate(sdpa_nodes): - - with gm.graph.inserting_before(sdpa_node): - # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) - q_node, k_node, v_node = sdpa_node.args[:3] - incoming_key, incoming_value = incoming_keys_values[idx] - # Create nodes for concatenating k with incoming_key and v with incoming_value - concatenated_k_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension - kwargs={} - ) - concatenated_v_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension - kwargs={} - ) - - # Create the torch.cond node - cond_k_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(is_generate_input, concatenated_k_node, k_node), - ) - - cond_v_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(is_generate_input, concatenated_v_node, v_node), - ) - - sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] - - return gm - - - -@_aten_lowering_pass -def insert_dynamic_kv_cache( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert FlashInfer MHA + KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - - # Add static key and value as inputs to the graph - kv_inputs, is_generate_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) - - # Call the function to add KV as outputs - logits_keys_values = add_kv_as_outputs(gm) - - # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases - gm = insert_torch_cond_before_sdpa(gm, kv_inputs, is_generate_input) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm - - diff --git a/examples/dynamo/llm/llm_pyt_benchmark.py b/examples/dynamo/llm/llm_pyt_benchmark.py deleted file mode 100644 index 9ae60576a5..0000000000 --- a/examples/dynamo/llm/llm_pyt_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import timeit - -USE_CACHE = True -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" -# MODEL_NAME = "Qwen/Qwen3-0.6B" -MAX_NEW_TOKENS = 128 - - -def main(): - # Initialize model and tokenizer - print("Loading model and tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - torch_dtype=torch.float16, - use_cache=USE_CACHE, - device_map="auto" - ) - # model.generation_config.cache_implementation = "static" - # model.forward = torch.compile(model.forward) - - # Prepare input prompt - word = "What" - # Tokenize the word - word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence - # Repeat the token 
2048 times - input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device - print(f"Input tensor shape: {input_ids.shape}") - - # # Warm-up pass - print("Running warm-up pass...") - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - - # Benchmark loop - print("Running benchmark...") - num_iterations = 10 - total_time = 0 - timings = [] - - for i in range(num_iterations): - start_time = timeit.default_timer() - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE - ) - end_time = timeit.default_timer() - generation_time = end_time - start_time - total_time += generation_time - timings.append(generation_time) - - # Decode and print first iteration output - # if i == 0: - # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) - # print("\nFirst generation output:") - # print(output_text) - - # Calculate and print statistics - average_time = total_time / num_iterations - print(f"\nPerformance Statistics:") - print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") - print(f"Average tokens per second: {100/average_time:.2f}") - print("\nIndividual timings (ms):") - for i, t in enumerate(timings): - print(f"Iteration {i+1}: {t*1000:.2f}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py deleted file mode 100644 index d536cd12e4..0000000000 --- a/examples/dynamo/llm/run_llm.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -.. _torch_export_gpt2: - -Compiling GPT2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. -""" - -import argparse -import copy -import os -import timeit - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from contextlib import nullcontext -from utils import export_llm, generate, recordStats, time_generate, generate_with_static_cache, generate_with_dynamic_cache -import sys -import os - -# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from register_sdpa import * - -DEVICE = torch.device("cuda:0") - -def get_model(args): - with torch.no_grad(): - # Supported list of models: - # - meta-llama/Llama-3.2-1B-Instruct - # - meta-llama/Llama-3.2-3B-Instruct - # - meta-llama/Llama-3.1-8B-Instruct - # - Qwen/Qwen2.5-1.5B-Instruct - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", - # num_hidden_layers=1 - ) - .eval() - .cuda() - ) - if args.precision == "FP16": - model = model.to(torch.float16) - elif args.precision == "BF16": - model = model.to(torch.bfloat16) - else: - model = model.to(torch.float32) - - return model - - -def compile_torchtrt(model, input_ids, args): - max_seq_len = input_ids.shape[1] + args.num_tokens - ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[input_ids, position_ids], - enabled_precisions=enabled_precisions, - # truncate_double=True, - use_explicit_typing=use_explicit_typing, - use_fp32_acc=use_fp32_acc, - device=DEVICE, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - offload_module_to_cpu=True, - min_block_size=args.min_block_size, - ) - - return trt_model - - -def print_outputs(backend_name, gen_tokens, tokenizer): - print(f"========= {backend_name} =========") - print( - f"{backend_name} model generated text: ", - tokenizer.decode(gen_tokens[0], skip_special_tokens=True), - ) - print("===================================") - - - -def measure_perf(trt_model, input_signature, backend_name): - # Measure average time for 10 iterations - import timeit - import numpy as np - - total_time = 0 - iterations = 10 - - print("Running warmup iteration...") - # Warmup run - _ = trt_model(*input_signature) - torch.cuda.synchronize() - - print(f"Measuring performance over {iterations} iterations...") - for i in range(iterations): - start_time = timeit.default_timer() - _ = trt_model(*input_signature) - torch.cuda.synchronize() - end_time = timeit.default_timer() - iter_time = end_time - start_time - total_time += iter_time - # print(f"Iteration {i+1}: {iter_time:.4f} seconds") - - avg_time = total_time / iterations - print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds") - print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second") - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run inference on a model with random input values" - ) - arg_parser.add_argument( - "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model" - ) - arg_parser.add_argument( - "--tokenizer", - type=str, - default="", - help="Name of LLM model tokenizer", - ) - arg_parser.add_argument( - "--prompt", type=str, default="What is parallel programming ?", help="Prompt" - ) - arg_parser.add_argument("--precision", type=str, default="FP16", 
help="Precision to use in the model. Options: FP16, BF16, FP32") - arg_parser.add_argument( - "--iterations", type=int, default=5, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--min_block_size", type=int, default=1, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--num_tokens", type=int, default=128, help="no. of output tokens to be generated" - ) - arg_parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size used for benchmarking" - ) - arg_parser.add_argument( - "--isl", type=int, default=2048, help="Input sequence length used for benchmarking" - ) - arg_parser.add_argument( - "--enable_pytorch_run", - action="store_true", - help="Enable pytorch run (default: False)" - ) - arg_parser.add_argument( - "--cache", - type=str, - default="", - help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", - ) - arg_parser.add_argument( - "--cudagraph", - action="store_true", - help="Enable cudagraphs (default: False)" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - arg_parser.add_argument( - "--benchmark", - action="store_true", - help="Enable benchmark (default: False)" - ) - - args = arg_parser.parse_args() - with torch.inference_mode(): - model = get_model(args) - - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) - - # Prepare input for benchmarking or evaluation - if args.benchmark: - input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - else: - model_inputs = tokenizer(args.prompt, return_tensors="pt") - input_ids = model_inputs["input_ids"].to(DEVICE) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens - # Pyt - pyt_gen_tokens = None - pyt_timings = None - pyt_stats = None - - if args.enable_pytorch_run: - pyt_gen_tokens = generate( - model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id - ) - if args.benchmark: - pyt_timings = time_generate( - generate, - model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - pyt_stats = recordStats( - "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None - ) - - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 - elif args.cache == "dynamic": - import dynamic_cache - - # Compile the model with Torch-TensorRT - trt_model = compile_torchtrt(model, input_ids, args) - - if args.cache == "static_v1" or args.cache == "static_v2": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
- # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) - - trt_gen_tokens = generate_with_static_cache( - trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - ) - - if args.benchmark: - trt_timings = time_generate( - generate_with_static_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - elif args.cache == "dynamic": - trt_gen_tokens = generate_with_dynamic_cache( - trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate_with_dynamic_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - else: - trt_gen_tokens = generate( - trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - - if args.benchmark: - trt_stats = recordStats( - "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None - ) - - - if not args.benchmark: - if args.enable_pytorch_run: - print_outputs("PyTorch", pyt_gen_tokens, tokenizer) - - print_outputs("TensorRT", trt_gen_tokens, tokenizer) - - if args.enable_pytorch_run: - print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}") - - if args.benchmark: - if args.enable_pytorch_run: - print("=========PyTorch PERFORMANCE============ \n") - print(pyt_stats) - print("===================== \n") - print("=========TensorRT PERFORMANCE============ \n") - print(trt_stats) diff --git a/examples/dynamo/llm/static_cache_v1.py b/examples/dynamo/llm/static_cache_v1.py deleted file mode 100644 index 943718de2e..0000000000 --- a/examples/dynamo/llm/static_cache_v1.py +++ /dev/null @@ -1,266 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch.fx import Node - -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -import torch.utils._pytree as pytree -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes -logger = logging.getLogger(__name__) - -SDPA_OP = torch._C._nn.scaled_dot_product_attention - -def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Modifies the graph to add query, key, and value tensors as outputs. - - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. 
- """ - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + tuple(kv_cache_for_graph) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - -def add_kv_cache_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors, index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - - # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. - input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] - input_ids_meta = input_nodes[0].meta["val"] - seq_len = input_ids_meta.shape[1] - min_max_opt = extract_var_range_info(seq_len) - max_seq_len = min_max_opt["max"] - - from torch.fx.experimental.symbolic_shapes import ShapeEnv - shape_env = ShapeEnv() - # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive - start_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(start_idx_unbacked_symint >= 0) - torch._check(start_idx_unbacked_symint <= max_seq_len) - - end_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(end_idx_unbacked_symint >= 0) - torch._check(end_idx_unbacked_symint <= max_seq_len) - # Set the symbolic ints as the metadata for start_idx and end_idx inputs - start_idx_input.meta["val"] = start_idx_unbacked_symint - end_idx_input.meta["val"] = end_idx_unbacked_symint - - return kv_inputs, start_idx_input, end_idx_input - - - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): - """ - Insert slicing operations before each scaled_dot_product_attention operation. 
- """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == SDPA_OP: - sdpa_nodes.append(node) - kv_cache_for_graph = [] - for idx, sdpa_node in enumerate(sdpa_nodes): - assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" - q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args - incoming_key, incoming_value = incoming_keys_values[idx] - kv_cache_for_sdpa_node = [] - new_keys_values = [] - for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): - # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim - with gm.graph.inserting_before(sdpa_node): - slice_1 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_2 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_1, 1), - kwargs={} - ) - slice_3 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_2, 2, None, start_idx_input), - kwargs={} - ) - slice_4 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_3, 3), - kwargs={} - ) - # =============================================== # - # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim - slice_5 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(key_or_value,), - kwargs={} - ) - slice_6 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_5, 1), - kwargs={} - ) - slice_7 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_6, 2, end_idx_input), - kwargs={} - ) - slice_8 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_7, 3), - kwargs={} - ) - # =============================================== # - # Concatenate the sliced tensors to build KV cache - cat = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([slice_4, current_key_or_value_node, slice_8], 2), - kwargs={} - ) - # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph - cat.meta.update(key_or_value.meta) - kv_cache_for_sdpa_node.append(cat) - # =============================================== # - # Get the current key and value by indexing the KV cache - slice_9 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(cat,), - kwargs={} - ) - slice_10 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_9, 1), - kwargs={} - ) - slice_11 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_10, 2, None, end_idx_input), - kwargs={} - ) - slice_12 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_11, 3), - kwargs={} - ) - new_keys_values.append(slice_12) - - kv_cache_for_graph.extend(kv_cache_for_sdpa_node) - - sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, True) - - return gm, kv_cache_for_graph - - -@_aten_lowering_pass -def insert_static_cache_v1( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert KV cache ops in the graph""" - """Perform insertion of kv-caches and 
attention kernel.""" - # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) - - # Build and update the KV cache using computed KV inputs for current token and - # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - - # Call the function to add KV as outputs - logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - - return gm - - diff --git a/examples/dynamo/llm/static_cache_v2.py b/examples/dynamo/llm/static_cache_v2.py deleted file mode 100644 index e2a40d39f7..0000000000 --- a/examples/dynamo/llm/static_cache_v2.py +++ /dev/null @@ -1,275 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch.fx import Node - -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -import torch.utils._pytree as pytree -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes -logger = logging.getLogger(__name__) - -SDPA_OP = torch._C._nn.scaled_dot_product_attention - -def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): - """ - Modifies the graph to add query, key, and value tensors as outputs. - - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. - """ - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + tuple(kv_cache_for_graph) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - -def add_kv_cache_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors, index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. 
- - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) - - # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx - input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] - # Get the third last input which should be the last value cache node and store the max_seq_len - input_ids_meta = input_nodes[-3].meta["val"] - seq_len = input_ids_meta.shape[2] - - if isinstance(seq_len, torch.SymInt): - min_max_opt = extract_var_range_info(seq_len) - max_seq_len = min_max_opt["max"] - else: - max_seq_len = seq_len - - from torch.fx.experimental.symbolic_shapes import ShapeEnv - shape_env = ShapeEnv() - # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive - start_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(start_idx_unbacked_symint >= 0) - torch._check(start_idx_unbacked_symint <= max_seq_len) - - end_idx_unbacked_symint = shape_env.create_unbacked_symint() - torch._check(end_idx_unbacked_symint >= 0) - torch._check(end_idx_unbacked_symint <= max_seq_len) - # Set the symbolic ints as the metadata for start_idx and end_idx inputs - start_idx_input.meta["val"] = start_idx_unbacked_symint - end_idx_input.meta["val"] = end_idx_unbacked_symint - - return kv_inputs, start_idx_input, end_idx_input - -def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input): - """ - Create slicing and concatenation nodes for KV cache update. - - This function creates the necessary slicing and concatenation nodes to update the KV cache - during the generation process. It takes the SDPA node, the current KV cache node, and the - incoming KV cache node as input. - Returns: - for a particular SDPA node, a tuple containing: - - List of new current KV nodes - - List of updated incoming KV cache nodes - - """ - - # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim - with gm.graph.inserting_before(sdpa_node): - slice_1 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(incoming_kv_node,), - kwargs={} - ) - slice_2 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_1, 1), - kwargs={} - ) - slice_3 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_2, 2, None, start_idx_input), - kwargs={} - ) - slice_4 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_3, 3), - kwargs={} - ) - # Concat key_cache[:,:,:start_idx,:] with current key (k) - concat_keys_or_values = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([slice_4, current_kv_node], 2), - kwargs={} - ) - - # =============================================== # - # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim - slice_5 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(incoming_kv_node,), - kwargs={} - ) - slice_6 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_5, 1), - kwargs={} - ) - slice_7 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_6, 2, end_idx_input), - kwargs={} - ) - slice_8 = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(slice_7, 3), - kwargs={} - ) - # =============================================== # - # Concatenate the sliced tensors to build KV cache - new_incoming_keys_or_values = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=([concat_keys_or_values, slice_8], 2), - kwargs={} - ) - # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph - new_incoming_keys_or_values.meta.update(incoming_kv_node.meta) - - return concat_keys_or_values, new_incoming_keys_or_values - -def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): - """ - Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: - concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) - concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) - new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) - new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) - out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) - """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == SDPA_OP: - sdpa_nodes.append(node) - kv_cache_for_graph = [] - for idx, sdpa_node in enumerate(sdpa_nodes): - assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" - q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args - incoming_key, incoming_value = incoming_keys_values[idx] - # For keys - new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input) - # For values - new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, 
v_node, incoming_value, start_idx_input, end_idx_input) - - # Store the KV cache nodes for the current SDPA node - kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) - - # Update the SDPA node arguments with current key and value nodes - sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, True) - - # kv_cache_for_graph.extend([k_node, v_node]) - return gm, kv_cache_for_graph - - -@_aten_lowering_pass -def insert_static_cache_v2( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - # Add static key and value as inputs to the graph - kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) - - # Build and update the KV cache using computed KV inputs for current token and - # incoming keys and values from previous tokens (which were added as inputs) - gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) - - # Call the function to add KV as outputs - logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm - - diff --git a/examples/dynamo/llm/test_gemma.py b/examples/dynamo/llm/test_gemma.py deleted file mode 100644 index dc665ce61b..0000000000 --- a/examples/dynamo/llm/test_gemma.py +++ /dev/null @@ -1,258 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import TestCase -from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer -from transformers.models.gemma3.configuration_gemma3 import Gemma3Config -from transformers import AutoModelForCausalLM -import torch_tensorrt -from contextlib import nullcontext -import argparse -import sys -import os - -# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from register_sdpa import * - - -ATOL = 1e-5 -RTOL = 1e-5 - - -gemma3_model_name = "google/gemma-3-1b-it" -gemma3_model = AutoModelForCausalLM.from_pretrained( - gemma3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ).eval().cuda() -GEMMA3_CONFIG = gemma3_model.config - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") - - -def test_gemma3_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = gemma3_model.model.layers[0].self_attn.to(DTYPE) - - # gemma3 - hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings, None) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - -def test_gemma3_attention_with_static_cache(args): - - import static_cache_v2 - DTYPE = torch.float32 - model = gemma3_model.model.layers[0].self_attn.to(DTYPE) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) - key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], 
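                            # The inputs beyond the exported signature (key_cache, value_cache,
                            # start_idx, end_idx, is_causal) correspond to the placeholder inputs
                            # that the KV-cache/SDPA lowering passes add to the graph (see
                            # add_kv_cache_inputs in static_cache_v2).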
- enabled_precisions={torch.float32}, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_python_runtime=True) - - # Test Prefill - trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) - print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() - position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) - - is_causal = False - out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) - out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) - out_pyt = out_no_cache[:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - -def test_gemma3_decoder(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - model = gemma3_model.model.layers[0].to(DTYPE) - # model.self_attn.is_sliding = False - - # gemma3 - hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda() - position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) - position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings_global, position_embeddings_local], - enabled_precisions={torch.float32}, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - # breakpoint() - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - -def test_gemma3_decoder_with_static_cache(args): - - class Gemma3DecoderLayerBlock(nn.Module): - def __init__(self, model): - super().__init__() - self.config = GEMMA3_CONFIG - self.decoder = Gemma3DecoderLayer( - config=self.config, - layer_idx=0) - self.model = model - def forward(self, hidden_states, position_embeddings): - return self.model(hidden_states, position_embeddings=position_embeddings) - - DTYPE = torch.float32 - model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE)) - - import static_cache_v2 - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + 
NUM_TOKENS - hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() - position_embeddings_global = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) - position_embeddings_local = (torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda()) - key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], - enabled_precisions={torch.float32}, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_python_runtime=True) - - # Test Prefill - trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) - print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() - position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) - - is_causal = False - out_no_cache = model(hidden_states_full, position_embeddings_full) - - out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) - out_pyt = out_no_cache[0][:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. 
Options: FP16, BF16, FP32") - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_gemma3_attention(args) - # test_gemma3_attention_with_static_cache(args) - test_gemma3_decoder(args) - # test_gemma3_decoder_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_llama_components.py b/examples/dynamo/llm/test_llama_components.py deleted file mode 100644 index c0445e1590..0000000000 --- a/examples/dynamo/llm/test_llama_components.py +++ /dev/null @@ -1,476 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import TestCase -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers import AutoModelForCausalLM -import torch_tensorrt -from contextlib import nullcontext -import argparse -import sys -import os - -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from register_sdpa import * -ATOL = 1e-5 -RTOL = 1e-5 - - -# llama2_model_name = "meta-llama/Llama-2-7b-hf" -llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" -llama_model = AutoModelForCausalLM.from_pretrained( - llama3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ).eval().cuda() -LLAMA_CONFIG = llama_model.config - -def test_llama_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) - model = llama_model.model.layers[0].self_attn.to(DTYPE) - # llama3 - hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - from torch.export._trace import _export - # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False) - ep = _export( - model, - args=(hidden_states, position_embeddings, None), - dynamic_shapes=dynamic_shapes, - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings, None) - if isinstance(pyt_output, tuple): - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - assert torch.allclose(pyt_output[0], trt_output[0], 
atol=ATOL, rtol=RTOL) - else: - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") - -def test_llama_attention_with_static_cache(args): - class LlamaAttentionBlock(nn.Module): - def __init__(self): - super().__init__() - self.config = LLAMA_CONFIG - self.attn = LlamaAttention( - config=self.config, - layer_idx=0 - ) - def forward(self, hidden_states, position_embeddings): - attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) - return attn_output - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - model = llama_model.model.layers[0].self_attn.to(DTYPE) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) - key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - import static_cache_v2 - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], - enabled_precisions=enabled_precisions, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - use_python_runtime=True) - - # Test Prefill - trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) - print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() - position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) - - is_causal = False - out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) - out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, 
position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) - out_pyt = out_no_cache[:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - - -def test_llama_decoder(args): - - class LlamaDecoderLayerBlock(nn.Module): - def __init__(self, model): - super().__init__() - self.config = LLAMA_CONFIG - self.decoder = LlamaDecoderLayer( - config=self.config, - layer_idx=0) - self.model = model - - def forward(self, hidden_states, position_embeddings): - return self.model(hidden_states, position_embeddings=position_embeddings) - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) - # llama3 - hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings], - enabled_precisions=enabled_precisions, - debug=args.debug, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing) - trt_output = trt_model(hidden_states, position_embeddings) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - -def test_llama_decoder_with_static_cache(args): - - class LlamaDecoderLayerBlock(nn.Module): - def __init__(self, model): - super().__init__() - self.config = LLAMA_CONFIG - self.decoder = LlamaDecoderLayer( - config=self.config, - layer_idx=0) - self.model = model - def forward(self, hidden_states, position_embeddings): - return self.model(hidden_states, position_embeddings=position_embeddings) - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) - key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - value_cache = 
torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(hidden_states, position_embeddings) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) - import static_cache_v2 - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], - enabled_precisions=enabled_precisions, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - use_python_runtime=True) - - # Test Prefill - trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) - print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() - position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) - - is_causal = False - out_no_cache = model(hidden_states_full, position_embeddings_full) - - out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) - out_pyt = out_no_cache[0][:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - - -def test_llama_model(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = llama_model.model.to(DTYPE) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() - - pyt_output = model(input_ids, position_ids) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}) - kwarg_inputs = {"position_ids":position_ids} - from torch.export._trace import _export - ep = _export(model, args=(input_ids,), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes, strict=False, allow_complex_guards_as_runtime_asserts=True) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - arg_inputs=[], - kwarg_inputs=kwarg_inputs, - enabled_precisions=enabled_precisions, - disable_tf32=True, - debug=args.debug, - 
offload_module_to_cpu=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - use_python_runtime=True) - - trt_output = trt_model(input_ids, position_ids) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") - breakpoint() - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - -def test_llama_model_with_static_cache(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - model = llama_model.model.to(DTYPE) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() - key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(input_ids) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}) - kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids} - ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes) - - import static_cache_v2 - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - arg_inputs=[], - kwarg_inputs=kwarg_inputs, - enabled_precisions=enabled_precisions, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - use_python_runtime=True) - - # Test Prefill - trt_output, key_cache, value_cache = trt_model(input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal) - pyt_output = pyt_output.last_hidden_state - print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda() - position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda() - - # Concatenate the current hidden_states with the previous ones - input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1) - position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1) - is_causal = False - kwarg_inputs = {"input_ids":input_ids_full, "position_ids":position_ids_full} - out_no_cache = model(**kwarg_inputs) - - out_trt, key_cache, value_cache = trt_model(input_ids_curr, position_ids_curr, key_cache, value_cache, start_idx, end_idx, is_causal) - out_pyt = out_no_cache.last_hidden_state[:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - input_ids = input_ids_full - position_ids = position_ids_full - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - 
arg_parser.add_argument( - "--precision", - type=str, - default="FP16", - help="Precision (default: FP16)" - ) - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_llama_attention(args) - # test_llama_decoder(args) - test_llama_model(args) - # test_llama_attention_with_static_cache(args) - # test_llama_decoder_with_static_cache(args) - # test_llama_model_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_qwen2.5_components.py b/examples/dynamo/llm/test_qwen2.5_components.py deleted file mode 100644 index 37ffbc5dd5..0000000000 --- a/examples/dynamo/llm/test_qwen2.5_components.py +++ /dev/null @@ -1,173 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import TestCase -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers import AutoModelForCausalLM -import torch_tensorrt -from contextlib import nullcontext -import argparse -import sys -import os - -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from register_sdpa import * - -ATOL = 1e-5 -RTOL = 1e-5 - - -qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct" -qwen2_5_model = AutoModelForCausalLM.from_pretrained( - qwen2_5_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ).eval().cuda() -QWEN_CONFIG = qwen2_5_model.config - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") - -def test_qwen_apply_rotary_pos_emb(args): - class QwenApplyRotaryPosEmb(nn.Module): - def __init__(self): - super().__init__() - - def rotate_half(self, x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1): - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed - - def forward(self, q, k, cos, sin, unsqueeze_dim=1): - return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim) - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE) - # Shapes for Qwen 2.5 - q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() - k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() - cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda() - sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda() - - pyt_output = model(q, k, cos, sin) - - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len}) - ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes) - with 
(torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[q, k, cos, sin], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug) - trt_output = trt_model(q, k, cos, sin) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - - -def test_qwen_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE) - # qwen2.5 - hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings, None) - - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings, None) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. 
Options: FP16, BF16, FP32") - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_qwen_apply_rotary_pos_emb(args) - test_qwen_attention(args) diff --git a/examples/dynamo/llm/test_qwen3.py b/examples/dynamo/llm/test_qwen3.py deleted file mode 100644 index e83419b717..0000000000 --- a/examples/dynamo/llm/test_qwen3.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import TestCase -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer -from transformers.models.qwen3.configuration_qwen3 import Qwen3Config -from transformers import AutoModelForCausalLM -import torch_tensorrt -from contextlib import nullcontext -import argparse -import sys -import os - -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from register_sdpa import * - -ATOL = 1e-5 -RTOL = 1e-5 - - -qwen3_model_name = "Qwen/Qwen3-0.6B" -qwen3_model = AutoModelForCausalLM.from_pretrained( - qwen3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ).eval().cuda() -QWEN_CONFIG = qwen3_model.config - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") - - -def test_qwen_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = qwen3_model.model.layers[0].self_attn.to(DTYPE) - # qwen2.5 - hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() - position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_embeddings, None) - - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug) - trt_output = trt_model(hidden_states, position_embeddings, None) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - -def test_qwen3_decoder(args): - - class QwenDecoderLayerBlock(nn.Module): - def __init__(self, model): - super().__init__() - self.config = QWEN_CONFIG - self.model 
= model - def forward(self, hidden_states, position_ids, position_embeddings): - return self.model(hidden_states, position_ids=position_ids, position_embeddings=position_embeddings) - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - model = QwenDecoderLayerBlock(qwen3_model.model.layers[0].to(DTYPE)) - # qwen3 - hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() - position_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() - position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) - - pyt_output = model(hidden_states, position_ids, position_embeddings) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export(model, (hidden_states, position_ids, position_embeddings), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[hidden_states, position_ids, position_embeddings], - enabled_precisions={torch.float32}, - debug=args.debug) - trt_output = trt_model(hidden_states, position_ids, position_embeddings) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - -def test_qwen3_model(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - model = qwen3_model.model.to(DTYPE) - # qwen3 - input_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() - position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).cuda().unsqueeze(0) - - pyt_output = model(input_ids, position_ids) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}) - ep = torch.export.export(model, (input_ids, position_ids), dynamic_shapes=dynamic_shapes) - - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile(ep, - inputs=[input_ids, position_ids], - enabled_precisions={torch.float32}, - use_python_runtime=True, - disable_tf32=True, - debug=args.debug) - # breakpoint() - trt_output = trt_model(input_ids, position_ids) - - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") - print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[2] - trt_output[2]))}") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - arg_parser.add_argument("--precision", type=str, default="FP32", help="Precision to use in the model. 
Options: FP16, BF16, FP32") - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_qwen_attention(args) - # test_qwen3_decoder(args) - test_qwen3_model(args) diff --git a/examples/dynamo/llm/test_static_cache.py b/examples/dynamo/llm/test_static_cache.py deleted file mode 100644 index 52807f5e93..0000000000 --- a/examples/dynamo/llm/test_static_cache.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch -import torch.nn as nn -from torch.export import export -import torch_tensorrt -from contextlib import nullcontext -import argparse -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers import AutoModelForCausalLM -from torch_tensorrt.dynamo.lowering import ( - get_decompositions, - post_lowering, - pre_export_lowering, -) -import sys -import os - -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -import register_sdpa - -ATOL = 1e-5 -RTOL = 1e-5 -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -class DynamicCacheModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, q, k, v, k1, v1, flag): - def true_fn(q, k, v, k1, v1): - k_new = torch.cat((k, k1), dim=2) - v_new = torch.cat((v, v1), dim=2) - return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) - - def false_fn(q, k, v, k1, v1): - return torch._C._nn.scaled_dot_product_attention(q, k, v) - - out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) - - return 2 * out - -class ModelNoCache(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, q, k, v): - return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) - -class StaticCacheModel(nn.Module): - def __init__(self): - super().__init__() - - # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): - # new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) - # new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) - # out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) - - # return out, new_key_cache, new_value_cache - - def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): - concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ] - concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) - new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) - new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) - out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) - - return out, new_key_cache, new_value_cache - - -def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, - is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: - """ - Eager implementation of SDPA - """ - import math - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - - if is_causal: - assert attn_mask is None - temp_mask = 
torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias = attn_mask + attn_bias - - if enable_gqa: - key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) - value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return attn_weight @ value - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") - - -def test_no_cache_model_with_torch_tensorrt(args): - """ - Test the no cache model - """ - with torch.inference_mode(): - model_no_cache = ModelNoCache().eval().cuda() - # q = torch.randn(1, 32, 6, 64).cuda() - # k = torch.randn(1, 32, 6, 64).cuda() - # v = torch.randn(1, 32, 6, 64).cuda() - q = torch.load("query.pt") - k = torch.load("key.pt") - v = torch.load("value.pt") - out_no_cache = model_no_cache(q, k, v) - out_eager = eager_sdpa(q, k, v, is_causal=True) - q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) - # Export the model - exported_program = torch.export.export( - model_no_cache, - args=(q, k, v), - dynamic_shapes=({2 : q_seq_len}, {2 : q_seq_len}, {2 : q_seq_len}), - strict=False - ) - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile( - exported_program, - inputs=[q, k, v], - enabled_precisions={torch.float32}, - disable_tf32=True, - debug=args.debug, - min_block_size=1, - ) - out_trt = trt_model(q, k, v) - - print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager") - print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt") - print_diff(out_eager, out_trt, "out_eager vs out_trt") - breakpoint() - - -def test_static_cache_model(args): - """ - Test the static cache model - """ - with torch.inference_mode(): - model_no_cache = ModelNoCache().eval().cuda() - model_static_cache = StaticCacheModel().eval().cuda() - q = torch.randn(1, 32, 2048, 64).cuda() - k = torch.randn(1, 32, 2048, 64).cuda() - v = torch.randn(1, 32, 2048, 64).cuda() - key_cache = torch.zeros(1, 32, 2176, 64).cuda() - value_cache = torch.zeros(1, 32, 2176, 64).cuda() - - # Test Prefill - start_idx = 0 - end_idx = 2048 - out_no_cache = model_no_cache(q, k, v) - out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) - assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - q_curr = torch.randn(1, 32, 1, 64).cuda() - k_curr = torch.randn(1, 32, 1, 64).cuda() - v_curr = torch.randn(1, 32, 1, 64).cuda() - - # Concatenate the current query, key, and value with the previous ones - q_full = torch.cat((q, q_curr), dim=2) - k_full = torch.cat((k, k_curr), dim=2) - v_full = torch.cat((v, v_curr), dim=2) - - out_no_cache = model_no_cache(q_full, k_full, v_full) - out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False) - - assert 
torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL) - q = q_full - k = k_full - v = v_full - print("============== test_static_cache passed ==============") - -def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args): - """ - Transform the graph module by adding key and value cache to the graph - """ - gm = exported_program.module() - # Post lower the model - settings = torch_tensorrt.dynamo.conversion.CompilationSettings( - enabled_precisions={torch.float32}, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - min_block_size=1, - ) - exported_program = pre_export_lowering(exported_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) - ) - - gm = exported_program.module() - gm = post_lowering(gm, settings) - - return gm - -def test_static_cache_lowering(args): - """ - Test static cache lowering pass applied to the model with no cache and run the graph module - and compare the output with the model with no cache - """ - import static_cache2 - - model_no_cache = ModelNoCache().eval().cuda() - q = torch.randn(1, 32, 2, 64).cuda() - k = torch.randn(1, 32, 2048, 64).cuda() - v = torch.randn(1, 32, 2048, 64).cuda() - key_cache = torch.zeros(1, 32, 2176, 64).cuda() - value_cache = torch.zeros(1, 32, 2176, 64).cuda() - - # Export the model - q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) - kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) - exported_program = export( - model_no_cache, - args=(q, k, v), - dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), - strict=False - ) - - gm = transform_gm_with_kv_cache(exported_program, args) - - # Test Prefill - start_idx = 0 - end_idx = 2048 - is_causal = True - q = torch.randn(1, 32, 2048, 64).cuda() - out_no_cache = model_no_cache(q, k, v) - out_pyt_cache, key_cache, value_cache = gm(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) - assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - is_causal = False - q_curr = torch.randn(1, 32, 1, 64).cuda() - k_curr = torch.randn(1, 32, 1, 64).cuda() - v_curr = torch.randn(1, 32, 1, 64).cuda() - # Concatenate the current query, key, and value with the previous ones - q_full = torch.cat((q, q_curr), dim=2) - k_full = torch.cat((k, k_curr), dim=2) - v_full = torch.cat((v, v_curr), dim=2) - - out_no_cache = model_no_cache(q_full, k_full, v_full) - out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, key_cache, value_cache, start_idx, end_idx, is_causal) - assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL) - q = q_full - k = k_full - v = v_full - - print("============== test_static_cache_lowering passed ==============") - -def test_static_cache_export(args): - """ - Test the static cache model export - """ - model_static_cache = StaticCacheModel().eval().cuda() - q = torch.randn(1, 32, 2048, 64).cuda() - k = torch.randn(1, 32, 2048, 64).cuda() - v = torch.randn(1, 32, 2048, 64).cuda() - key_cache = torch.zeros(1, 32, 2176, 64).cuda() - value_cache = torch.zeros(1, 32, 2176, 64).cuda() - # Test Prefill - start_idx = 0 - end_idx = 2048 - is_causal = True - # Export the model - seq_len = torch.export.Dim("seq_len", min=2, max=2048) - seq_len_dyn_dim = {2 : seq_len} - exported_program = export( - model_static_cache, - args=(q, k, v, key_cache, value_cache, start_idx, 
end_idx, is_causal), - dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None), - strict=False - ) - - -def test_static_cache_with_torch_tensorrt(args): - """ - Test the static cache model with torch_tensorrt - """ - import static_cache_v2 - - model_no_cache = ModelNoCache().eval().cuda() - q = torch.randn(1, 32, 2, 64).cuda() - k = torch.randn(1, 32, 2048, 64).cuda() - v = torch.randn(1, 32, 2048, 64).cuda() - key_cache = torch.zeros(1, 32, 2176, 64).cuda() - value_cache = torch.zeros(1, 32, 2176, 64).cuda() - - # Export the model - q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) - kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) - exported_program = export( - model_no_cache, - args=(q, k, v), - dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), - strict=False - ) - with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): - trt_model = torch_tensorrt.dynamo.compile( - exported_program, - inputs=[q, k, v], - enabled_precisions={torch.float32}, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - min_block_size=1, - ) - - start_idx = 0 - end_idx = 2048 - is_causal = True - q = torch.randn(1, 32, 2048, 64).cuda() - # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) - out_no_cache = model_no_cache(q, k, v) - out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) - - assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match" - assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match" - assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match" - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - q_curr = torch.randn(1, 32, 1, 64).cuda() - k_curr = torch.randn(1, 32, 1, 64).cuda() - v_curr = torch.randn(1, 32, 1, 64).cuda() - # Concatenate the current query, key, and value with the previous ones - q_full = torch.cat((q, q_curr), dim=2) - k_full = torch.cat((k, k_curr), dim=2) - v_full = torch.cat((v, v_curr), dim=2) - is_causal = True - out_no_cache = model_no_cache(q_full, k_full, v_full) - out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) - # breakpoint() - # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") - # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") - # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}") - assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}" - assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}" - assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}" - q = q_full - k = k_full - v = v_full - - print("============== test_static_cache_with_torch_tensorrt passed ==============") - - -def main(): - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - 
arg_parser.add_argument( - "--debug", - action="store_true", - help="Enable debug (default: False)" - ) - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_no_cache_model_with_torch_tensorrt(args) - # test_static_cache_model(args) - # test_static_cache_lowering(args) - test_static_cache_with_torch_tensorrt(args) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/dynamo/llm/utils.py b/examples/dynamo/llm/utils.py deleted file mode 100644 index c43f90acc5..0000000000 --- a/examples/dynamo/llm/utils.py +++ /dev/null @@ -1,216 +0,0 @@ -import torch -from transformers import StoppingCriteriaList -from transformers.generation.stopping_criteria import ( - EosTokenCriteria, - MaxLengthCriteria, -) -import numpy as np -import copy -import timeit - -def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): - """ - Exports the LLM model into an ExportedProgram with dynamic shapes. - In the case of guard failures due to some PyTorch kernel implements, we also - try to re-export the graph by expressing them as runtime assert nodes - """ - with torch.no_grad(): - # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 - seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) - position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) - try: - print("Trying to export the model using torch.export.export()..") - # strict=False only enables aotautograd tracing and excludes dynamo. - ep = torch.export.export( - model, args=(inputs,), kwargs={"position_ids":position_ids}, dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False - ) - except: - print( - "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" - ) - # This API is used to express the constraint violation guards as asserts in the graph. - ep = torch.export._trace._export( - model, - args=(inputs,), - kwargs={"position_ids":position_ids}, - dynamic_shapes=({1: seq_len}, {1: seq_len}), - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - - return ep - -def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): - """ - Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. - - This function identifies placeholder nodes in the graph that represent KV cache tensors, - and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. - - Args: - model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders - - Returns: - tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph - """ - # placeholder nodes are expected to be in the following order: - # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx - placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] - # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors. - kv_cache_inputs = placeholder_nodes[2:-2] - zeroed_kv_cache_inputs = [] - for input in kv_cache_inputs: - zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) - - return tuple(zeroed_kv_cache_inputs) - -def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule): - """ - Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. 
This should only be used for dynamic cache. - - This function identifies placeholder nodes in the graph that represent KV cache tensors, - and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. - - Args: - model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders - - Returns: - tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph - """ - # placeholder nodes are expected to be in the following order: - # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx - placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] - # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors. - kv_cache_inputs = placeholder_nodes[2:-1] - zeroed_kv_cache_inputs = [] - for input in kv_cache_inputs: - zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) - - return tuple(zeroed_kv_cache_inputs) - - -def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): - """ - Greedy decoding of the model. This generates up to max_tokens. - """ - stopping_criteria = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=max_output_seq_length), - EosTokenCriteria(eos_token_id=eos_token_id), - ] - ) - isl = input_seq.shape[1] - osl = max_output_seq_length - isl - - num_tokens_generated = 0 - while num_tokens_generated < osl: - position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() - outputs = model(input_seq, position_ids=position_ids) - logits = outputs.logits - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1) - input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) - num_tokens_generated += 1 - # TODO: Handle batch in this check - if not benchmark and stopping_criteria(input_seq, logits).item(): - break - - return input_seq - -def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id): - """ - Greedy decoding of the model with static KV cache. - """ - start_idx = 0 - end_idx = input_seq.shape[1] - position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() - output_seq = input_seq.clone() - # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL - num_tokens_generated = 0 - kv_cache = get_zeroed_static_cache_inputs(model) - while end_idx < max_output_seq_length: - position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids - input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx) - logits_keys_values = model(*input_signature) - num_tokens_generated += 1 - logits = logits_keys_values[0] - kv_cache = logits_keys_values[1:] - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) - output_seq = torch.cat([output_seq, next_tokens], dim=-1) - input_seq = next_tokens - start_idx = end_idx - end_idx = start_idx + 1 - return output_seq - -def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id): - """ - Greedy decoding of the model with dynamic KV cache. 
- """ - position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() - output_seq = input_seq.clone() - num_output_tokens = max_output_seq_length - input_seq.shape[1] - num_tokens_generated = 0 - kv_cache = get_zeroed_dynamic_cache_inputs(model) - last_position_id = position_ids[-1, -1].item() - breakpoint() - while num_tokens_generated < num_output_tokens: - is_generate = False if input_seq.shape[1] > 1 else True - position_ids = torch.tensor([[last_position_id+1]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids - input_signature = (input_seq, position_ids, *kv_cache, is_generate) - logits_keys_values = model(*input_signature) - num_tokens_generated += 1 - logits = logits_keys_values[0] - kv_cache = logits_keys_values[1:] - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) - output_seq = torch.cat([output_seq, next_tokens], dim=-1) - input_seq = next_tokens - last_position_id += 1 - return output_seq - - -def time_generate( - generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 -): - """ - Measure the time for generating a sentence over certain number of iterations - """ - timings = [] - for _ in range(iterations): - start_time = timeit.default_timer() - _ = generate_fn( - model, inputs, output_seq_length, eos_token_id - ) - torch.cuda.synchronize() - end_time = timeit.default_timer() - timings.append(end_time - start_time) - - return timings - - -def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): - """ - Records different timing stats and adds it to the result - """ - times = np.array(timings) - speeds = batch_size / times - time_mean = np.mean(times).item() - time_med = np.median(times).item() - time_99th = np.percentile(times, 99).item() - time_std = np.std(times, ddof=0).item() - speed_mean = np.mean(speeds).item() - speed_med = np.median(speeds).item() - - stats = { - "Backend": backend, - "Precision": precision, - "Batch size": batch_size, - "Median(FPS)": speed_med, - "Mean(FPS)": speed_mean, - "Median-Latency(ms)": time_med * 1000, - "Mean-Latency(ms)": time_mean * 1000, - "Latency-StdDev(ms)": time_std * 1000, - "Compile Time(s)": compile_time_s, - } - return stats \ No newline at end of file diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py deleted file mode 100644 index 906673a806..0000000000 --- a/examples/dynamo/register_sdpa.py +++ /dev/null @@ -1,122 +0,0 @@ -import copy -import logging -import operator -from typing import Callable, Sequence, Tuple - -import torch -from sdpa_converter import * -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check -from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) - -# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention -# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. 
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) -TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_efficient_attention.default, None -) -TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_flash_attention.default, None -) - -REPLACEABLE_ATEN_OPS = { - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, -} - - -@_aten_lowering_pass -def replace_variants_of_sdpa( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace scaled_dot_product_attention with an equivalent - implementation which can be accurately converted to TRT - """ - attn_mask = None - is_causal = True - for node in gm.graph.nodes: - if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - if ( - node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ): - if len(node.args) == 7: - ( - query, - key, - value, - attn_bias, - compute_log_sumexp, - dropout_p, - is_causal, - ) = node.args - elif len(node.args) == 5: - query, key, value, attn_mask, is_causal = node.args - dropout_p = 0.0 - - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - elif ( - node.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ): - if len(node.args) == 6: - query, key, value, dropout_p, is_causal, return_debug_mask = ( - node.args - ) - if len(node.args) == 5: - query, key, value, dropout_p, is_causal = ( - node.args - ) - elif len(node.args) == 3: - query, key, value = node.args - dropout_p = 0.0 - is_causal = True - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - - logger.warning(f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations.") - modified_input_args = (query, key, value, None, dropout_p, True) - # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, is_causal). kwargs has scale - with gm.graph.inserting_after(node): - new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, - args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None), "use_fp32_acc": settings.use_fp32_acc}, - ) - - # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
- new_node.meta = copy.copy(node.meta) - # Check if there's a getitem node following this attention node - for user in list(node.users): - if user.op == "call_function" and user.target == operator.getitem: - # If the getitem is extracting the first element (the output tensor) - if user.args[1] == 0: - # Replace all uses of the getitem with the new attention node - user.replace_all_uses_with(new_node) - new_node.meta["val"] = new_node.meta["val"][0] - # Replace all uses of the original node with the new node - node.replace_all_uses_with(new_node) - - gm.graph.erase_node(node) - - # Clean up the graph - clean_up_graph_after_modifications(gm) - - logger.debug( - "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" - ) - return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py deleted file mode 100644 index c60ad915dd..0000000000 --- a/examples/dynamo/sdpa_converter.py +++ /dev/null @@ -1,200 +0,0 @@ -import logging -import math -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import tensorrt as trt -import torch -import torch_tensorrt -from torch.fx.node import Target -from torch_tensorrt._enums import dtype -from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, - cast_trt_tensor, - get_trt_tensor, -) -from torch_tensorrt.fx.types import TRTTensor - -logger = logging.getLogger(__name__) - - -def tril( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - row: TRTTensor, - col: TRTTensor, -) -> TRTTensor: - row_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 - ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - - col_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 - ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] - ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor - ) - return mask - - -@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( - torch.nn.functional.scaled_dot_product_attention, - enabled=True, - supports_dynamic_shapes=True, -) -def scaled_dot_product_attention( - ctx: torch_tensorrt.dynamo.conversion.ConversionContext, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: str, -) -> TRTTensor: - # TODO: Handle attn_mask and is_causal arguments in the future - query, key, value, attn_mask, dropout_p, is_causal = args - - # TODO: remove this once we have a better way to handle the causal mask - scale = kwargs.get("scale", None) - source_ir = SourceIR.ATEN - is_causal = True - # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - use_fp32_acc = kwargs.get("use_fp32_acc", False) - query_dtype = query.dtype - - if scale is None: - scale = query.shape[-1] - if scale < 0: - # dynamic shape - scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) - sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) - else: - # static shape - sqrt_scaled = math.sqrt(scale) - key = impl.elementwise.div( 
- ctx, - target, - source_ir, - name + "_scale", - key, - sqrt_scaled, - ) - else: - key = impl.elementwise.mul( - ctx, - target, - source_ir, - name + "_scale", - key, - scale, - ) - - if use_fp32_acc and query_dtype == trt.float16: - query = cast_trt_tensor( - ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir - ) - key = cast_trt_tensor( - ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir - ) - - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) - - if use_fp32_acc: - mm = cast_trt_tensor( - ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir - ) - - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, 2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. - # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq( - ctx, target, source_ir, name + "_eq", L, S - ) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) - - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) - - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias - ) - - softmax = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False - ) - if use_fp32_acc: - softmax = cast_trt_tensor( - ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir - ) - value = cast_trt_tensor( - ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir - ) - out = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_out", - softmax, - value, - ) - if use_fp32_acc: - out = cast_trt_tensor( - ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir - ) - - return out diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py deleted file mode 100644 index 4d34c58de4..0000000000 --- a/examples/dynamo/torch_export_gpt2.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -.. _torch_export_gpt2: - -Compiling GPT2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. 
-""" - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import export_llm, generate - -# %% - -# Define the parameters and initialize the model -MAX_TOKENS = 32 -DEVICE = torch.device("cuda:0") - -# Define the GPT2 model from hugging face -# kv_cache is not supported in Torch-TRT currently. -# CPU is used here so that GPU memory is reserved for TRT compilation. -with torch.no_grad(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = ( - AutoModelForCausalLM.from_pretrained( - "gpt2", - pad_token_id=tokenizer.eos_token_id, - use_cache=False, - attn_implementation="eager", - ) - .eval() - .half() - ) - -# %% -# Tokenize a sample input prompt and get pytorch model outputs -prompt = "I enjoy walking with my cute dog" -model_inputs = tokenizer(prompt, return_tensors="pt") -input_ids = model_inputs["input_ids"] - -# Auto-regressive generation loop for greedy decoding using PyTorch model -# We use a custom generate function which is very similar to the huggingface one. -pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - - -# %% -# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Export the GPT2 model into an ExportedProgram which is input of TRT compilation -# To compile the model in FP16, we do the following -# 1) Cast the model to FP16 via model.half() -# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation -# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) -gpt2_ep = export_llm(model, input_ids, max_seq_len=1024) -trt_model = torch_tensorrt.dynamo.compile( - gpt2_ep, - inputs=[input_ids], - enabled_precisions={torch.float32}, - truncate_double=True, - device=DEVICE, - disable_tf32=True, - use_explicit_typing=True, - use_fp32_acc=True, -) - -# Auto-regressive generation loop for greedy decoding using TensorRT model -# We use a custom generate function which is very similar to the huggingface one. -# Move inputs to GPU -input_ids = input_ids.to(DEVICE) -trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Decode the output sentences of PyTorch and TensorRT -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -print("=============================") -print( - "Pytorch model generated text: ", - tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True), -) -print("=============================") -print( - "TensorRT model generated text: ", - tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True), -) - -# Prompt : What is parallel programming ? - -# ============================= -# Pytorch model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that - -# ============================= -# TensorRT model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. 
The main difference between parallel programming and parallel programming is that diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py deleted file mode 100644 index 2f3e3cba43..0000000000 --- a/examples/dynamo/torch_export_llama2.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -.. _torch_export_llama2: - -Compiling Llama2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model. -""" - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import export_llm, generate - -# %% -# Define the parameters and initialize the model -MAX_TOKENS = 32 -DEVICE = torch.device("cuda:0") - -# Define the Llama2 model from hugging face -# kv_cache is not supported in Torch-TRT currently. -# CPU is used here so that GPU memory is reserved for TRT compilation. -llama_path = "meta-llama/Llama-2-7b-chat-hf" -with torch.no_grad(): - model = ( - AutoModelForCausalLM.from_pretrained( - llama_path, use_cache=False, attn_implementation="eager" - ) - .eval() - .half() - ) - -tokenizer = AutoTokenizer.from_pretrained(llama_path) - -# %% -# Tokenize a sample input prompt and get pytorch model outputs -prompt = "What is dynamic programming?" -model_inputs = tokenizer(prompt, return_tensors="pt") -input_ids = model_inputs.input_ids - -# Auto-regressive generation loop for greedy decoding using PyTorch model -# We use a custom generate function which is very similar to the huggingface one. -pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Export the llama2 model into an ExportedProgram which is input of TRT compilation -# To compile the model in FP16, we do the following -# 1) Cast the model to FP16 via model.half() -# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation -# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) -llama2_ep = export_llm(model, input_ids, max_seq_len=64) -trt_model = torch_tensorrt.dynamo.compile( - llama2_ep, - inputs=[input_ids], - enabled_precisions={torch.float32}, - truncate_double=True, - device=DEVICE, - disable_tf32=True, - use_explicit_typing=True, - use_fp32_acc=True, -) - -# Auto-regressive generation loop for greedy decoding using TensorRT model -# We use a custom generate function which is very similar to the huggingface one. -# Move inputs to GPU -input_ids = input_ids.to(DEVICE) -trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Decode the output sentences of PyTorch and TensorRT -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -print("=============================") -print( - "Pytorch model generated text: ", - tokenizer.batch_decode( - pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0], -) -print("=============================") -print( - "TensorRT model generated text: ", - tokenizer.batch_decode( - trt_gen_tokens, - skip_special_tokens=True, - clean_up_tokenization_spaces=False, - )[0], -) - - -# Prompt : What is dynamic programming? 
- -# ============================= -# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and - -# ============================= -# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and From ecf88d1759abbb780a981b6df31b0f8007c52b05 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Jun 2025 22:50:10 +0000 Subject: [PATCH 23/30] chore: move code to tools/llm --- .../dynamo/lowering/_decomposition_groups.py | 1 - .../lowering/passes/_aten_lowering_pass.py | 3 +- .../runtime/_PythonTorchTensorRTModule.py | 13 +- tools/llm/cache_utils.py | 202 ++++++ tools/llm/dynamic_cache.py | 215 +++++++ tools/llm/llm_pyt_benchmark.py | 83 +++ tools/llm/register_sdpa.py | 125 ++++ tools/llm/run_llm.py | 344 ++++++++++ tools/llm/run_vlm.py | 333 ++++++++++ tools/llm/sdpa_converter.py | 196 ++++++ tools/llm/static_cache_v1.py | 277 ++++++++ tools/llm/static_cache_v2.py | 290 +++++++++ tools/llm/test_gemma.py | 389 +++++++++++ tools/llm/test_llama_components.py | 603 ++++++++++++++++++ tools/llm/test_qwen2.5_components.py | 193 ++++++ tools/llm/test_qwen3.py | 223 +++++++ tools/llm/test_static_cache.py | 468 ++++++++++++++ tools/llm/utils.py | 244 +++++++ 18 files changed, 4194 insertions(+), 8 deletions(-) create mode 100644 tools/llm/cache_utils.py create mode 100644 tools/llm/dynamic_cache.py create mode 100644 tools/llm/llm_pyt_benchmark.py create mode 100644 tools/llm/register_sdpa.py create mode 100644 tools/llm/run_llm.py create mode 100644 tools/llm/run_vlm.py create mode 100644 tools/llm/sdpa_converter.py create mode 100644 tools/llm/static_cache_v1.py create mode 100644 tools/llm/static_cache_v2.py create mode 100644 tools/llm/test_gemma.py create mode 100644 tools/llm/test_llama_components.py create mode 100644 tools/llm/test_qwen2.5_components.py create mode 100644 tools/llm/test_qwen3.py create mode 100644 tools/llm/test_static_cache.py create mode 100644 tools/llm/utils.py diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 6df05f6940..825be75076 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -171,7 +171,6 @@ aten.upsample_bilinear2d.vec, aten.upsample_trilinear3d.vec, aten.upsample_bicubic2d.vec, - aten.linear, } diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 6e2019ad71..2ecc45ecf3 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,7 +6,6 @@ from torch_tensorrt.dynamo.utils import is_tegra_platform from .accumulate_fp32_matmul import accumulate_fp32_matmul -from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast @@ -26,7 +25,7 @@ replace_max_pool_with_indices, remove_assert_nodes, accumulate_fp32_matmul, - # remove_num_users_is_0_nodes, + remove_num_users_is_0_nodes, ] if not is_tegra_platform(): diff --git 
a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index fe4b781505..84711f154e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -743,11 +743,14 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - tensor_inputs = [ - t if isinstance(t, torch.Tensor) else torch.tensor(t) - for t in inputs - ] - new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) + tensor_inputs = [] + for t in inputs: + if not isinstance(t, torch.Tensor): + return True + tensor_inputs.append(t) + new_shape_key = "".join( + str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs + ) # If the new shape key differs from the existing one, # invalidate the old shape key and remove the CUDAGraph diff --git a/tools/llm/cache_utils.py b/tools/llm/cache_utils.py new file mode 100644 index 0000000000..7089d9a220 --- /dev/null +++ b/tools/llm/cache_utils.py @@ -0,0 +1,202 @@ +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import tensorrt +import torch +import torch_tensorrt +from torch._export.utils import _detect_fake_mode_from_gm +from torch._ops import OpOverloadPacket +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Target +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._pytree import _LEAF_SPEC + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True +) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. + + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output + + +def get_kv_nodes(gm): + """ + Get the key and value nodes from the graph. + """ + kv_nodes = [] + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch._C._nn.scaled_dot_product_attention + ): + q_node, k_node, v_node = node.args[:3] + kv_nodes.append((k_node, v_node)) + return kv_nodes + + +def get_random_tensor_from_node(node: Node) -> torch.Tensor: + """ + Creates a random tensor based on the shape information in a node's metadata. 
+ For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + node: A torch.fx.Node object with metadata containing tensor information + + Returns: + A random tensor with shape matching the node's metadata, or None if no valid + tensor information is found + """ + if "val" not in node.meta: + raise ValueError( + f"No tensor information found in node metadata for node: {node}" + ) + + fake_tensor = node.meta["val"] + shape = [] + + # Iterate through each dimension and handle symbolic dimensions + for dim in fake_tensor.shape: + if isinstance(dim, torch.SymInt): + # Extract the maximum value from the shape environment + max_val = dim.node.hint + shape.append(max_val) + else: + shape.append(dim) + + # Create a random tensor with the determined shape + dtype = fake_tensor.dtype + device = fake_tensor.device + random_tensor = torch.rand(shape, dtype=dtype, device=device) + + return random_tensor + + +def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: + """ + Creates random tensors based on the shape information in node metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + nodes: List of torch.fx.Node objects with metadata + + Returns: + List of random tensors with shapes matching the nodes' metadata + """ + random_tensors = [] + + for node in nodes: + if isinstance(node, Node): + node_tensor = get_random_tensor_from_node(node) + elif isinstance(node, tuple): + node_tensor_list = [] + for n in node: + random_tensor = get_random_tensor_from_node(n) + node_tensor_list.append(random_tensor) + node_tensor = tuple(node_tensor_list) + + random_tensors.append(node_tensor) + + return random_tensors + + +def add_graph_input( + gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None +) -> Node: + """Add a graph input to the given GraphModule and return the newly created node. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the input to. + name (str): The name of the input. + val (torch.Tensor): An example tensor to use for the input. + dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] + """ + # check that no dynamic shape is provided... 
+ if dynamic_shape: + raise NotImplementedError("Dynamic shape not supported for adding graph inputs") + + # extract graph and input spec + graph: Graph = gm.graph + + in_spec = graph._codegen.pytree_info.in_spec + in_spec_for_args = in_spec.children_specs[0] + orig_args = graph._codegen.pytree_info.orig_args + assert in_spec_for_args.type is tuple + + # insert input node after currently last input node + node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] + with graph.inserting_after(node_last_input): + in_node = graph.placeholder(name) + in_spec_for_args.children_specs.append(_LEAF_SPEC) + orig_args.append(f"arg_{name}") + + # update pytree info recursively with __post_init__ starting at leaves + def call_post_init(spec): + for child_spec in spec.children_specs: + call_post_init(child_spec) + spec.__post_init__() + + call_post_init(in_spec) + + # set fake tensor information if all required information is available + fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) + if fake_mode and val is not None and isinstance(val, torch.Tensor): + if isinstance(val, FakeTensor): + fake_tensor = val + else: + fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) + in_node.meta["val"] = fake_tensor + in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) + + # return new node... + return in_node + + +def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: + """Check if the node is a call to one of the ops.""" + if node.op != "call_function": + return False + # check if it's a single op that's provided + if isinstance(ops, OpOverloadPacket): + ops = [ops] + + # check if it's the op itself instead of an overload + if any(node.target == op for op in ops): + return True + + return False + + +def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: + input_nodes: List[Node] = graph.find_nodes(op="placeholder") + output_nodes: List[Node] = graph.find_nodes(op="output") + return (input_nodes, output_nodes) diff --git a/tools/llm/dynamic_cache.py b/tools/llm/dynamic_cache.py new file mode 100644 index 0000000000..b45ebb6d43 --- /dev/null +++ b/tools/llm/dynamic_cache.py @@ -0,0 +1,215 @@ +import logging +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from cache_utils import ( + add_graph_input, + create_random_output_tensors, + get_kv_nodes, + is_op, +) +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info + +logger = logging.getLogger(__name__) + + +def add_kv_as_outputs(gm): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. 
+ """ + # list of MHA kernels we would want to detect and replace + mha_ops = { + torch._C._nn.scaled_dot_product_attention, + } + + # Find all SDPA nodes in the graph + mha_nodes = [] + for node in gm.graph.nodes: + if is_op(node, mha_ops): + mha_nodes.append(node) + + # Iterate through each MHA node to extract shape information + for mha_node in mha_nodes: + if "val" in mha_node.meta and len(mha_node.args) >= 3: + # Get the input nodes (query, key, value) + q_node, k_node, v_node = mha_node.args[:3] + + # Add the copy nodes as outputs to the graph + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + ((k_node, v_node),) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs, (k_node, v_node)) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors and index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add is_generate as input + is_generate_input = add_graph_input(gm, "is_generate", True) + is_generate_input.meta["val"] = torch.tensor(True) + + return kv_inputs, is_generate_input + + +def insert_torch_cond_before_sdpa( + gm, + incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], + is_generate_input: torch.Tensor, +): + """ + Insert a torch.cond operation before each scaled_dot_product_attention operation. 
+ + Args: + gm: The FX GraphModule to modify + + Returns: + The modified GraphModule + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch._C._nn.scaled_dot_product_attention + ): + sdpa_nodes.append(node) + + # For each SDPA node, insert a torch.cond operation before it + for idx, sdpa_node in enumerate(sdpa_nodes): + + with gm.graph.inserting_before(sdpa_node): + # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + # Create nodes for concatenating k with incoming_key and v with incoming_value + concatenated_k_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=( + [incoming_key, k_node], + 2, + ), # Concatenate along sequence length dimension + kwargs={}, + ) + concatenated_v_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=( + [incoming_value, v_node], + 2, + ), # Concatenate along sequence length dimension + kwargs={}, + ) + + # Create the torch.cond node + cond_k_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(is_generate_input, concatenated_k_node, k_node), + ) + + cond_v_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(is_generate_input, concatenated_v_node, v_node), + ) + + sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] + + return gm + + +@_aten_lowering_pass +def insert_dynamic_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert FlashInfer MHA + KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + + # Add static key and value as inputs to the graph + kv_inputs, is_generate_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm) + + # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases + gm = insert_torch_cond_before_sdpa(gm, kv_inputs, is_generate_input) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm diff --git a/tools/llm/llm_pyt_benchmark.py b/tools/llm/llm_pyt_benchmark.py new file mode 100644 index 0000000000..f3d68a951a --- /dev/null +++ b/tools/llm/llm_pyt_benchmark.py @@ -0,0 +1,83 @@ +import timeit + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +USE_CACHE = True +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_NEW_TOKENS = 128 + + +def main(): + # Initialize model and tokenizer + print("Loading model and tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, torch_dtype=torch.float16, use_cache=USE_CACHE, device_map="auto" + ) + # model.generation_config.cache_implementation = "static" + # model.forward = torch.compile(model.forward) + + # Prepare input prompt + word = "What" + # Tokenize the word + word_ids = tokenizer(word, return_tensors="pt").input_ids[ + 0 + ] # Get the first (and only) sequence + # Repeat the token 2048 
times + input_ids = ( + word_ids.repeat(1024).unsqueeze(0).to(model.device) + ) # Add batch dimension and move to device + print(f"Input tensor shape: {input_ids.shape}") + + # # Warm-up pass + print("Running warm-up pass...") + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE, + ) + + # Benchmark loop + print("Running benchmark...") + num_iterations = 10 + total_time = 0 + timings = [] + + for i in range(num_iterations): + start_time = timeit.default_timer() + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE, + ) + end_time = timeit.default_timer() + generation_time = end_time - start_time + total_time += generation_time + timings.append(generation_time) + + # Decode and print first iteration output + # if i == 0: + # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + # print("\nFirst generation output:") + # print(output_text) + + # Calculate and print statistics + average_time = total_time / num_iterations + print(f"\nPerformance Statistics:") + print( + f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds" + ) + print(f"Average tokens per second: {100/average_time:.2f}") + print("\nIndividual timings (ms):") + for i, t in enumerate(timings): + print(f"Iteration {i+1}: {t*1000:.2f}") + + +if __name__ == "__main__": + main() diff --git a/tools/llm/register_sdpa.py b/tools/llm/register_sdpa.py new file mode 100644 index 0000000000..c3c76e0f2d --- /dev/null +++ b/tools/llm/register_sdpa.py @@ -0,0 +1,125 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from sdpa_converter import * +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + +# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention +# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. 
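+# TORCH_TRT_DECOMPOSITIONS maps aten ops to their decomposition functions; popping the
+# SDPA entries below keeps these ops intact during lowering so that the converter
+# defined in sdpa_converter.py (imported above) handles them directly. Importing this
+# module applies the registrations as a side effect, which is why the scripts in this
+# directory use `from register_sdpa import *`.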
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_efficient_attention.default, None +) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_flash_attention.default, None +) + +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +@_aten_lowering_pass +def replace_variants_of_sdpa( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace scaled_dot_product_attention with an equivalent + implementation which can be accurately converted to TRT + """ + attn_mask = None + is_causal = True + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + if ( + node.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ): + if len(node.args) == 7: + ( + query, + key, + value, + attn_bias, + compute_log_sumexp, + dropout_p, + is_causal, + ) = node.args + elif len(node.args) == 5: + query, key, value, attn_mask, is_causal = node.args + dropout_p = 0.0 + + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + elif ( + node.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ): + if len(node.args) == 6: + query, key, value, dropout_p, is_causal, return_debug_mask = ( + node.args + ) + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args + elif len(node.args) == 3: + query, key, value = node.args + dropout_p = 0.0 + is_causal = True + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) + + logger.warning( + f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." + ) + modified_input_args = (query, key, value, None, dropout_p, True) + # Create a new node with torch.nn.functional.scaled_dot_product_attention + # The input args is (query, key, value, is_causal). kwargs has scale + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.nn.functional.scaled_dot_product_attention, + args=modified_input_args, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + }, + ) + + # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
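+            # copy.copy is a shallow copy: only the meta dict object is duplicated, while the
+            # FakeTensor values inside it are shared by reference, which is what avoids the
+            # data-pointer error mentioned above.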
+ new_node.meta = copy.copy(node.meta) + # Check if there's a getitem node following this attention node + for user in list(node.users): + if user.op == "call_function" and user.target == operator.getitem: + # If the getitem is extracting the first element (the output tensor) + if user.args[1] == 0: + # Replace all uses of the getitem with the new attention node + user.replace_all_uses_with(new_node) + new_node.meta["val"] = new_node.meta["val"][0] + # Replace all uses of the original node with the new node + node.replace_all_uses_with(new_node) + + gm.graph.erase_node(node) + + # Clean up the graph + clean_up_graph_after_modifications(gm) + + logger.debug( + "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" + ) + return gm diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py new file mode 100644 index 0000000000..de8ab9d92a --- /dev/null +++ b/tools/llm/run_llm.py @@ -0,0 +1,344 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using the dynamo backend +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. +""" + +import argparse +import copy +import os +import timeit +from contextlib import nullcontext + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from register_sdpa import * +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import ( + export_llm, + generate, + generate_with_dynamic_cache, + generate_with_static_cache, + recordStats, + time_generate, +) + +DEVICE = torch.device("cuda:0") + + +def get_model(args): + with torch.no_grad(): + # Supported list of models: + # - meta-llama/Llama-3.2-1B-Instruct + # - meta-llama/Llama-3.2-3B-Instruct + # - meta-llama/Llama-3.1-8B-Instruct + # - Qwen/Qwen2.5-1.5B-Instruct + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + # num_hidden_layers=1 + ) + .eval() + .cuda() + ) + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + + return model + + +def compile_torchtrt(model, input_ids, args): + max_seq_len = input_ids.shape[1] + args.num_tokens + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + 
print("===================================") + + +def measure_perf(trt_model, input_signature, backend_name): + # Measure average time for 10 iterations + import timeit + + import numpy as np + + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + # print(f"Iteration {i+1}: {iter_time:.4f} seconds") + + avg_time = total_time / iterations + print( + f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" + ) + print( + f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" + ) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model", + ) + arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="no. of output tokens to be generated", + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length used for benchmarking", + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)", + ) + arg_parser.add_argument( + "--cache", + type=str, + default="", + help="Type of KV cache to use. 
Options: static_v1, static_v2, dynamic", + ) + arg_parser.add_argument( + "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmark (default: False)" + ) + + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + # Pyt + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + + if args.enable_pytorch_run: + pyt_gen_tokens = generate( + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + elif args.cache == "dynamic": + import dynamic_cache + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
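+                # Illustrative note: CUDA graphs replay a previously captured launch sequence,
+                # so the first pass through both the prefill step and a single-token generate
+                # step below is what records the graphs that later iterations replay.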
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + elif args.cache == "dynamic": + trt_gen_tokens = generate_with_dynamic_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate_with_dynamic_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = recordStats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..470b0e6d99 --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,333 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using the dynamo backend +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. +""" + +import argparse +import copy +import os +import sys +import timeit +from contextlib import nullcontext + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import ( + export_llm, + generate, + generate_with_kv_cache, + recordStats, + time_generate, +) + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +DEVICE = torch.device("cuda:0") + + +def get_model(args): + with torch.no_grad(): + # Supported list of models: + # - meta-llama/Llama-3.2-1B-Instruct + # - meta-llama/Llama-3.2-3B-Instruct + # - meta-llama/Llama-3.1-8B-Instruct + # - Qwen/Qwen2.5-1.5B-Instruct + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=2, + ) + .eval() + .cuda() + ) + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + + return model + + +def compile_torchtrt(model, input_ids, args): + max_seq_len = input_ids.shape[1] + args.num_tokens + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +def measure_perf(trt_model, input_signature, backend_name): + # Measure average time for 10 iterations + import timeit + + import numpy as np + + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + # print(f"Iteration {i+1}: {iter_time:.4f} seconds") + + avg_time = total_time / iterations + print( + f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" + ) + print( + f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" + ) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model", + ) + arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument( + "--precision", + 
type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="no. of output tokens to be generated", + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length used for benchmarking", + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)", + ) + arg_parser.add_argument( + "--cache", + type=str, + default="", + help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", + ) + arg_parser.add_argument( + "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmark (default: False)" + ) + + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + # Pyt + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + if args.enable_pytorch_run: + pyt_gen_tokens = generate( + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + elif args.cache == "dynamic": + import dynamic_cache + + trt_model = compile_torchtrt(model, input_ids, args) + + if ( + args.cache == "static_v1" + or args.cache == "static_v2" + or args.cache == "dynamic" + ): + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_kv_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_kv_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = recordStats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/tools/llm/sdpa_converter.py b/tools/llm/sdpa_converter.py new file mode 100644 index 0000000000..47083c7b48 --- /dev/null +++ b/tools/llm/sdpa_converter.py @@ -0,0 +1,196 @@ +import logging +import math +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + +logger = logging.getLogger(__name__) + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + row: TRTTensor, + col: TRTTensor, +) -> TRTTensor: + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + row_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] + ) + + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + col_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + ) + + mask = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + ) + return mask + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + enabled=True, + supports_dynamic_shapes=True, +) +def scaled_dot_product_attention( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> TRTTensor: + # TODO: Handle attn_mask and is_causal arguments in the future + query, key, value, attn_mask, dropout_p, is_causal = args + + # TODO: remove this once we have a 
better way to handle the causal mask + scale = kwargs.get("scale", None) + source_ir = SourceIR.ATEN + is_causal = True + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + use_fp32_acc = kwargs.get("use_fp32_acc", False) + query_dtype = query.dtype + + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + key = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + key, + sqrt_scaled, + ) + else: + key = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + key, + scale, + ) + + if use_fp32_acc and query_dtype == trt.float16: + query = cast_trt_tensor( + ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir + ) + key = cast_trt_tensor( + ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir + ) + + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + + if use_fp32_acc: + mm = cast_trt_tensor( + ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir + ) + + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. 
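+        # Eager-mode sketch of the bias construction below (illustration only, not executed
+        # by this converter): masked positions become -inf and allowed positions become 0.
+        #
+        #   tril = torch.tril(torch.ones(L, S, dtype=torch.bool))
+        #   temp = (~tril).float()              # 1.0 where attention is disallowed
+        #   attn_bias = torch.log(1.0 - temp)   # log(1) = 0, log(0) = -inf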
+ # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False + ) + if use_fp32_acc: + softmax = cast_trt_tensor( + ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir + ) + value = cast_trt_tensor( + ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + if use_fp32_acc: + out = cast_trt_tensor( + ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir + ) + + return out diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py new file mode 100644 index 0000000000..a87495953d --- /dev/null +++ b/tools/llm/static_cache_v1.py @@ -0,0 +1,277 @@ +import logging +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from torch.fx import Node +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info + +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. 
+ + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + input_ids_meta = input_nodes[0].meta["val"] + seq_len = input_ids_meta.shape[1] + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + + +def insert_kv_slicing_before_sdpa( + gm, + incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], + start_idx_input: Node, + end_idx_input: Node, +): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert ( + len(sdpa_node.args) == 6 + ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip( + [incoming_key, incoming_value], [k_node, v_node] + ): + # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={}, + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={}, + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={}, + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={}, + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={}, + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={}, + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={}, + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={}, + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={}, + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={}, + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={}, + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={}, + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + ( + attn_mask, + dropout_p, + True, + ) + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v1( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa( + gm, kv_inputs, start_idx_input, end_idx_input + ) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = 
pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py new file mode 100644 index 0000000000..ad386d39f2 --- /dev/null +++ b/tools/llm/static_cache_v2.py @@ -0,0 +1,290 @@ +import logging +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from torch.fx import Node +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info + +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. 
+ + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + # Get the third last input which should be the last value cache node and store the max_seq_len + input_ids_meta = input_nodes[-3].meta["val"] + seq_len = input_ids_meta.shape[2] + + if isinstance(seq_len, torch.SymInt): + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + else: + max_seq_len = seq_len + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + + +def create_kv_cache_update_nodes( + gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input +): + """ + Create slicing and concatenation nodes for KV cache update. + + This function creates the necessary slicing and concatenation nodes to update the KV cache + during the generation process. It takes the SDPA node, the current KV cache node, and the + incoming KV cache node as input. + Returns: + for a particular SDPA node, a tuple containing: + - List of new current KV nodes + - List of updated incoming KV cache nodes + + """ + + # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={}, + ) + slice_2 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={}, + ) + slice_4 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={} + ) + # Concat key_cache[:,:,:start_idx,:] with current key (k) + concat_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_kv_node], 2), + kwargs={}, + ) + + # =============================================== # + # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={}, + ) + slice_6 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={}, + ) + slice_8 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + new_incoming_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([concat_keys_or_values, slice_8], 2), + kwargs={}, + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + new_incoming_keys_or_values.meta.update(incoming_kv_node.meta) + + return concat_keys_or_values, new_incoming_keys_or_values + + +def insert_kv_slicing_before_sdpa( + gm, + incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], + start_idx_input: Node, + end_idx_input: Node, +): + """ + Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert ( + len(sdpa_node.args) == 6 + ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + # For keys + new_current_key_node, new_incoming_key_cache_node = ( + create_kv_cache_update_nodes( + gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input + ) + ) + # For values + new_current_value_node, new_incoming_value_cache_node = ( + 
create_kv_cache_update_nodes( + gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input + ) + ) + + # Store the KV cache nodes for the current SDPA node + kv_cache_for_graph.extend( + [new_incoming_key_cache_node, new_incoming_value_cache_node] + ) + + # Update the SDPA node arguments with current key and value nodes + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + ( + attn_mask, + dropout_p, + True, + ) + + # kv_cache_for_graph.extend([k_node, v_node]) + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v2( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa( + gm, kv_inputs, start_idx_input, end_idx_input + ) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm diff --git a/tools/llm/test_gemma.py b/tools/llm/test_gemma.py new file mode 100644 index 0000000000..a5dc420ade --- /dev/null +++ b/tools/llm/test_gemma.py @@ -0,0 +1,389 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.gemma3.configuration_gemma3 import Gemma3Config +from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3DecoderLayer, +) + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +gemma3_model_name = "google/gemma-3-1b-it" +gemma3_model = ( + AutoModelForCausalLM.from_pretrained( + gemma3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +GEMMA3_CONFIG = gemma3_model.config + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_gemma3_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # gemma3 + hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 5, 256), dtype=DTYPE).cuda(), + torch.randn((1, 5, 256), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_gemma3_attention_with_static_cache(args): + + import static_cache_v2 + + DTYPE = torch.float32 + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), + ) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[ + hidden_states, + 
position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 256), dtype=DTYPE).cuda(), + torch.randn((1, 1, 256), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model( + hidden_states_curr, + position_embeddings_curr, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_gemma3_decoder(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + model = gemma3_model.model.layers[0].to(DTYPE) + # model.self_attn.is_sliding = False + + # gemma3 + hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda() + position_embeddings_global = ( + torch.randn((1, 6, 256), dtype=DTYPE).cuda(), + torch.randn((1, 6, 256), dtype=DTYPE).cuda(), + ) + position_embeddings_local = ( + torch.randn((1, 6, 256), dtype=DTYPE).cuda(), + torch.randn((1, 6, 256), dtype=DTYPE).cuda(), + ) + + pyt_output = model( + hidden_states, position_embeddings_global, position_embeddings_local + ) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ( + {1: seq_len}, + ({1: seq_len}, {1: seq_len}), + ({1: seq_len}, {1: seq_len}), + ) + ep = torch.export.export( + model, + (hidden_states, position_embeddings_global, position_embeddings_local), + dynamic_shapes=dynamic_shapes, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[ + hidden_states, + position_embeddings_global, + position_embeddings_local, + ], + enabled_precisions={torch.float32}, + debug=args.debug, + ) + trt_output = trt_model( + hidden_states, position_embeddings_global, position_embeddings_local + ) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + # breakpoint() + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + + +def test_gemma3_decoder_with_static_cache(args): + + class Gemma3DecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = GEMMA3_CONFIG + self.decoder = Gemma3DecoderLayer(config=self.config, layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return 
self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE)) + + import static_cache_v2 + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings_global = ( + torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), + ) + position_embeddings_local = ( + torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), + torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), + ) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model( + hidden_states, position_embeddings_global, position_embeddings_local + ) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[ + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 256), dtype=DTYPE).cuda(), + torch.randn((1, 1, 256), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model( + hidden_states_curr, + position_embeddings_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. 
Options: FP16, BF16, FP32", + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_gemma3_attention(args) + # test_gemma3_attention_with_static_cache(args) + test_gemma3_decoder(args) + # test_gemma3_decoder_with_static_cache(args) diff --git a/tools/llm/test_llama_components.py b/tools/llm/test_llama_components.py new file mode 100644 index 0000000000..ef7e59cd72 --- /dev/null +++ b/tools/llm/test_llama_components.py @@ -0,0 +1,603 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +# llama2_model_name = "meta-llama/Llama-2-7b-hf" +llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" +llama_model = ( + AutoModelForCausalLM.from_pretrained( + llama3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +LLAMA_CONFIG = llama_model.config + + +def test_llama_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + model = llama_model.model.layers[0].self_attn.to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + from torch.export._trace import _export + + # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False) + ep = _export( + model, + args=(hidden_states, position_embeddings, None), + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + if isinstance(pyt_output, tuple): + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print(f"Diff b/w pyt and trt: 
{torch.mean(torch.abs(pyt_output - trt_output))}") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_llama_attention_with_static_cache(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention(config=self.config, layer_idx=0) + + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn( + hidden_states, position_embeddings, None + ) + return attn_output + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + ) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[ + hidden_states, + position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model( + 
hidden_states_curr, + position_embeddings_curr, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_decoder(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions=enabled_precisions, + debug=args.debug, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + ) + trt_output = trt_model(hidden_states, position_embeddings) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + + +def test_llama_decoder_with_static_cache(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + ) + key_cache = 
torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes + ) + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[ + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model( + hidden_states_curr, + position_embeddings_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"position_ids": position_ids} + from torch.export._trace import _export + + ep = _export( + model, + args=(input_ids,), + kwargs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = 
torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + trt_output = trt_model(input_ids, position_ids) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + breakpoint() + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_llama_model_with_static_cache(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(input_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids} + ep = torch.export.export( + model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes + ) + + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model( + input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal + ) + pyt_output = pyt_output.last_hidden_state + print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda() + position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda() + + # Concatenate the current hidden_states with the previous ones + input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1) + position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1) + is_causal = False + kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full} + out_no_cache = model(**kwarg_inputs) + + out_trt, key_cache, value_cache = trt_model( + input_ids_curr, + position_ids_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache.last_hidden_state[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + input_ids = input_ids_full + position_ids = position_ids_full + + +if __name__ == "__main__": 
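    # Note: only test_llama_model(args) is enabled in the block below; the attention,
    # decoder and static-cache variants are left commented out and can be run individually.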
+ arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", type=str, default="FP16", help="Precision (default: FP16)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_llama_attention(args) + # test_llama_decoder(args) + test_llama_model(args) + # test_llama_attention_with_static_cache(args) + # test_llama_decoder_with_static_cache(args) + # test_llama_model_with_static_cache(args) diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py new file mode 100644 index 0000000000..60482bf22d --- /dev/null +++ b/tools/llm/test_qwen2.5_components.py @@ -0,0 +1,193 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct" +qwen2_5_model = ( + AutoModelForCausalLM.from_pretrained( + qwen2_5_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +QWEN_CONFIG = qwen2_5_model.config + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_qwen_apply_rotary_pos_emb(args): + class QwenApplyRotaryPosEmb(nn.Module): + def __init__(self): + super().__init__() + + def rotate_half(self, x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, q, k, cos, sin, unsqueeze_dim=1): + return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE) + # Shapes for Qwen 2.5 + q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + + pyt_output = model(q, k, cos, sin) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, 
{1: seq_len}) + ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[q, k, cos, sin], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(q, k, cos, sin) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. 
Options: FP16, BF16, FP32", + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_apply_rotary_pos_emb(args) + test_qwen_attention(args) diff --git a/tools/llm/test_qwen3.py b/tools/llm/test_qwen3.py new file mode 100644 index 0000000000..e46f17050a --- /dev/null +++ b/tools/llm/test_qwen3.py @@ -0,0 +1,223 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen3_model_name = "Qwen/Qwen3-0.6B" +qwen3_model = ( + AutoModelForCausalLM.from_pretrained( + qwen3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +QWEN_CONFIG = qwen3_model.config + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen3_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_qwen3_decoder(args): + + class QwenDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = QWEN_CONFIG + self.model = model + + def forward(self, 
hidden_states, position_ids, position_embeddings): + return self.model( + hidden_states, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = QwenDecoderLayerBlock(qwen3_model.model.layers[0].to(DTYPE)) + # qwen3 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_embeddings = ( + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_ids, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, + (hidden_states, position_ids, position_embeddings), + dynamic_shapes=dynamic_shapes, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_ids, position_embeddings], + enabled_precisions={torch.float32}, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_ids, position_embeddings) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + + +def test_qwen3_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = qwen3_model.model.to(DTYPE) + # qwen3 + input_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_ids = ( + torch.arange(input_ids.shape[1], dtype=torch.int64).cuda().unsqueeze(0) + ) + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + ep = torch.export.export( + model, (input_ids, position_ids), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions={torch.float32}, + use_python_runtime=True, + disable_tf32=True, + debug=args.debug, + ) + # breakpoint() + trt_output = trt_model(input_ids, position_ids) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}" + ) + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[2] - trt_output[2]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP32", + help="Precision to use in the model. 
Options: FP16, BF16, FP32", + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_attention(args) + # test_qwen3_decoder(args) + test_qwen3_model(args) diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py new file mode 100644 index 0000000000..af72f93b67 --- /dev/null +++ b/tools/llm/test_static_cache.py @@ -0,0 +1,468 @@ +import argparse +import os +import sys +from contextlib import nullcontext + +import torch +import torch.nn as nn +import torch_tensorrt +from torch.export import export +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import register_sdpa + +ATOL = 1e-5 +RTOL = 1e-5 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class DynamicCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, flag): + def true_fn(q, k, v, k1, v1): + k_new = torch.cat((k, k1), dim=2) + v_new = torch.cat((v, v1), dim=2) + return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) + + def false_fn(q, k, v, k1, v1): + return torch._C._nn.scaled_dot_product_attention(q, k, v) + + out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) + + return 2 * out + + +class ModelNoCache(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + return torch._C._nn.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=True + ) + + +class StaticCacheModel(nn.Module): + def __init__(self): + super().__init__() + + # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + # new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + # new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + # out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + + # return out, new_key_cache, new_value_cache + + def forward( + self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ): + concat_keys = torch.cat( + (key_cache[:, :, :start_idx, :], k), dim=2 + ) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ] + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat( + (concat_values, value_cache[:, :, end_idx:, :]), dim=2 + ) + out = torch._C._nn.scaled_dot_product_attention( + q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal + ) + + return out, new_key_cache, new_value_cache + + +def eager_sdpa( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +) -> torch.Tensor: + """ + Eager implementation of SDPA + """ + import math + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + + if is_causal: + assert attn_mask is 
None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_no_cache_model_with_torch_tensorrt(args): + """ + Test the no cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + # q = torch.randn(1, 32, 6, 64).cuda() + # k = torch.randn(1, 32, 6, 64).cuda() + # v = torch.randn(1, 32, 6, 64).cuda() + q = torch.load("query.pt") + k = torch.load("key.pt") + v = torch.load("value.pt") + out_no_cache = model_no_cache(q, k, v) + out_eager = eager_sdpa(q, k, v, is_causal=True) + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + # Export the model + exported_program = torch.export.export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: q_seq_len}, {2: q_seq_len}), + strict=False, + ) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + min_block_size=1, + ) + out_trt = trt_model(q, k, v) + + print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager") + print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt") + print_diff(out_eager, out_trt, "out_eager vs out_trt") + breakpoint() + + +def test_static_cache_model(args): + """ + Test the static cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Test Prefill + start_idx = 0 + end_idx = 2048 + out_no_cache = model_no_cache(q, k, v) + out_static_cache, new_key_cache, new_value_cache = model_static_cache( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ) + assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_static_cache, new_key_cache, new_value_cache = model_static_cache( + q_curr, + k_curr, + v_curr, + new_key_cache, + new_value_cache, + start_idx, + 
end_idx, + is_causal=False, + ) + + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL + ) + q = q_full + k = k_full + v = v_full + print("============== test_static_cache passed ==============") + + +def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args): + """ + Transform the graph module by adding key and value cache to the graph + """ + gm = exported_program.module() + # Post lower the model + settings = torch_tensorrt.dynamo.conversion.CompilationSettings( + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions(get_decompositions(False)) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + return gm + + +def test_static_cache_lowering(args): + """ + Test static cache lowering pass applied to the model with no cache and run the graph module + and compare the output with the model with no cache + """ + import static_cache2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}), + strict=False, + ) + + gm = transform_gm_with_kv_cache(exported_program, args) + + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + out_no_cache = model_no_cache(q, k, v) + out_pyt_cache, key_cache, value_cache = gm( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal + ) + assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + is_causal = False + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_pyt_static_cache, key_cache, value_cache = gm( + q_curr, + k_curr, + v_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL + ) + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_lowering passed ==============") + + +def test_static_cache_export(args): + """ + Test the static cache model export + """ + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + # Export the model + seq_len = torch.export.Dim("seq_len", min=2, max=2048) + seq_len_dyn_dim = {2: seq_len} + exported_program = export( + 
model_static_cache, + args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), + dynamic_shapes=( + seq_len_dyn_dim, + seq_len_dyn_dim, + seq_len_dyn_dim, + None, + None, + torch.export.Dim.DYNAMIC, + torch.export.Dim.DYNAMIC, + None, + ), + strict=False, + ) + + +def test_static_cache_with_torch_tensorrt(args): + """ + Test the static cache model with torch_tensorrt + """ + import static_cache_v2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}), + strict=False, + ) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) + out_no_cache = model_no_cache(q, k, v) + out_trt, trt_key_cache, trt_value_cache = trt_model( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal + ) + + assert torch.allclose( + out_no_cache, out_trt, atol=ATOL, rtol=RTOL + ), "Prefill TRT logits don't match" + assert torch.allclose( + trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL + ), "Prefill TRT key cache don't match" + assert torch.allclose( + trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL + ), "Prefill TRT value cache don't match" + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + is_causal = True + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_trt, trt_key_cache, trt_value_cache = trt_model( + q_curr, + k_curr, + v_curr, + trt_key_cache, + trt_value_cache, + start_idx, + end_idx, + is_causal, + ) + # breakpoint() + # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") + # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") + # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}") + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL + ), f"Generate TRT logits don't match for idx {start_idx}" + assert torch.allclose( + trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL + ), f"Generate TRT key cache don't match for idx {start_idx}" + assert torch.allclose( + trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL + ), f"Generate TRT value cache don't match for idx {start_idx}" + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_with_torch_tensorrt 
passed ==============") + + +def main(): + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_no_cache_model_with_torch_tensorrt(args) + # test_static_cache_model(args) + # test_static_cache_lowering(args) + test_static_cache_with_torch_tensorrt(args) + + +if __name__ == "__main__": + main() diff --git a/tools/llm/utils.py b/tools/llm/utils.py new file mode 100644 index 0000000000..5ccb9d0e55 --- /dev/null +++ b/tools/llm/utils.py @@ -0,0 +1,244 @@ +import copy +import timeit + +import numpy as np +import torch +from transformers import StoppingCriteriaList +from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, +) + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors. 
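    # In other words, every placeholder except the first two (input_ids, position_ids)
    # and the last two (start_idx, end_idx) is treated as a KV-cache tensor, so the
    # [2:-2] slice below must stay in sync with how the static cache lowering pass
    # orders the inputs it inserts.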
+ kv_cache_inputs = placeholder_nodes[2:-2] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append( + torch.zeros( + input.meta["val"].shape, + dtype=input.meta["val"].dtype, + device=torch.device("cuda:0"), + ) + ) + + return tuple(zeroed_kv_cache_inputs) + + +def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-1] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append( + torch.zeros( + input.meta["val"].shape, + dtype=input.meta["val"].dtype, + device=torch.device("cuda:0"), + ) + ) + + return tuple(zeroed_kv_cache_inputs) + + +def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + isl = input_seq.shape[1] + osl = max_output_seq_length - isl + + num_tokens_generated = 0 + while num_tokens_generated < osl: + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + outputs = model(input_seq, position_ids=position_ids) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + num_tokens_generated += 1 + # TODO: Handle batch in this check + if not benchmark and stopping_criteria(input_seq, logits).item(): + break + + return input_seq + + +def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with static KV cache. 
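    The compiled module is expected to follow the static-cache calling
    convention used in the loop below (names here are illustrative):

        outputs = model(input_ids, position_ids, *kv_cache, start_idx, end_idx)
        logits, kv_cache = outputs[0], outputs[1:]

    i.e. the updated key/value caches are returned alongside the logits and are
    fed back in on the next decoding step.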
+ """ + start_idx = 0 + end_idx = input_seq.shape[1] + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL + num_tokens_generated = 0 + kv_cache = get_zeroed_static_cache_inputs(model) + while end_idx < max_output_seq_length: + position_ids = ( + torch.tensor([[start_idx]], dtype=torch.int64).cuda() + if input_seq.shape[1] == 1 + else position_ids + ) + input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + start_idx = end_idx + end_idx = start_idx + 1 + return output_seq + + +def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with dynamic KV cache. + """ + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + num_output_tokens = max_output_seq_length - input_seq.shape[1] + num_tokens_generated = 0 + kv_cache = get_zeroed_dynamic_cache_inputs(model) + last_position_id = position_ids[-1, -1].item() + breakpoint() + while num_tokens_generated < num_output_tokens: + is_generate = False if input_seq.shape[1] > 1 else True + position_ids = ( + torch.tensor([[last_position_id + 1]], dtype=torch.int64).cuda() + if input_seq.shape[1] == 1 + else position_ids + ) + input_signature = (input_seq, position_ids, *kv_cache, is_generate) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + last_position_id += 1 + return output_seq + + +def time_generate( + generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn(model, inputs, output_seq_length, eos_token_id) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings + + +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): + """ + Records different timing stats and adds it to the result + """ + times = np.array(timings) + speeds = batch_size / times + time_mean = np.mean(times).item() + time_med = np.median(times).item() + time_99th = np.percentile(times, 99).item() + time_std = np.std(times, ddof=0).item() + speed_mean = np.mean(speeds).item() + speed_med = np.median(speeds).item() + + stats = { + "Backend": backend, + "Precision": precision, + "Batch size": batch_size, + "Median(FPS)": speed_med, + "Mean(FPS)": speed_mean, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, + } + return stats From 723d1b28bf6d2992b33af660a22be13d6ee70b67 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Jun 2025 22:55:09 +0000 Subject: [PATCH 24/30] chore: 
remove files not considered for release Signed-off-by: Dheeraj Peri --- .../lowering/passes/constant_folding.py | 15 +- .../lower_scaled_dot_product_attention.py | 172 ---- .../dynamo/runtime/_PythonCUDAGraphModule.py | 771 ------------------ tools/llm/dynamic_cache.py | 215 ----- tools/llm/llm_pyt_benchmark.py | 83 -- tools/llm/run_vlm.py | 333 -------- tools/llm/test_gemma.py | 389 --------- tools/llm/test_qwen3.py | 223 ----- 8 files changed, 14 insertions(+), 2187 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py delete mode 100644 py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py delete mode 100644 tools/llm/dynamic_cache.py delete mode 100644 tools/llm/llm_pyt_benchmark.py delete mode 100644 tools/llm/run_vlm.py delete mode 100644 tools/llm/test_gemma.py delete mode 100644 tools/llm/test_qwen3.py diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 172d902a40..9d894651ad 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -55,6 +55,7 @@ def constant_fold( del cf logger.debug(f"Graph after constant folding:\n{gm.graph}") + return gm @@ -98,6 +99,18 @@ class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + # Set of known quantization ops to be excluded from constant folding. + # Currently, we exclude all quantization ops coming from modelopt library. + quantization_ops = {} + try: + # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered + import modelopt.torch.quantization as mtq + + assert torch.ops.tensorrt.quantize_op.default + quantization_ops.add(torch.ops.tensorrt.quantize_op.default) + except Exception as e: + pass + if quantization_ops and node.target in quantization_ops: + return True return False diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py deleted file mode 100644 index 89558acade..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ /dev/null @@ -1,172 +0,0 @@ -import copy -import logging -import operator -from typing import Callable, Sequence, Tuple - -import torch -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) -REPLACEABLE_ATEN_OPS = { - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, -} - - -def lower_scaled_dot_product_attention( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace specific versions of scaled_dot_product_attention with an equivalent - implementation which can be easily converted to TRT - """ - original_fns, replacement = scaled_dot_product_attention_replacement() - replaced_nodes = [] - sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default] - breakpoint() - 
# For each original function, search for it in the graph and replace - for original in original_fns: - replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( - gm, - original, - replacement, - ignore_literals=True, - ) - breakpoint() - if replaced_nodes: - # Repair instances which use the kwargs field (specifically the "scale" kwarg) - # Also repair instances which specified the is_causal or attn_bias fields - for match in replaced_nodes: - attention_node_replaced = None - # Seek the attention operator being replaced - for node in match.nodes_map: - if node.target in REPLACEABLE_ATEN_OPS: - attention_node_replaced = match.nodes_map[node] - break - - assert attention_node_replaced is not None - assert len(match.replacements) == 1 - - new_attention_node = match.replacements[0] - - assert ( - new_attention_node.target - == torch.nn.functional.scaled_dot_product_attention - ) - - # Copy the metadata of the replaced attention node to the new node - # TODO: Investigate why there are multiple FakeTensors in the metadata. - # We only use the first one as it contains the output shape information for this node. - if "val" in attention_node_replaced.meta: - new_attention_node.meta["val"] = copy.copy( - attention_node_replaced.meta["val"][0] - ) - - # If the attention operator had keyword-args, copy them to the new node - if attention_node_replaced.kwargs: - new_attention_node.kwargs = {**attention_node_replaced.kwargs} - - # Set default args in new node: - # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False - breakpoint() - new_attention_node.args = new_attention_node.args + (None, 0.0, False) - breakpoint() - # The `is_causal` argument was specified - if ( - ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ) - and args_bounds_check(attention_node_replaced.args, 4, False) - ) or ( - ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ) - and args_bounds_check(attention_node_replaced.args, 6, False) - ): - new_attention_node.args = ( - new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] - ) - - # The `attn_bias` argument was specified - if ( - attention_node_replaced.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ) and args_bounds_check(attention_node_replaced.args, 3) is not None: - new_attention_node.args = ( - new_attention_node.args[:3] - + attention_node_replaced.args[3] - + new_attention_node.args[4:] - ) - - gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") - - return gm - - -def scaled_dot_product_attention_replacement() -> Tuple[ - Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], -]: - """Constructs the original and replacement functions for efficient attention""" - - # Efficient Attention original graph - def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, - k, - v, - None, - False, - ) - out = operator.getitem(outputs, 0) - return out - - # Flash Attention original graph - def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - ) - out = operator.getitem(outputs, 0) - return out - - # Efficient Attention w/Scale 
original graph - def efficient_scale( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( - q, - k, - v, - None, - False, - scale=1.0, - ) - out = operator.getitem(outputs, 0) - return out - - # Flash Attention w/Scale original graph - def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - scale=1.0, - ) - out = operator.getitem(outputs, 0) - return out - - # Replacement graph consists of the functional version of scaled_dot_product_attention - def replacement( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - return torch.nn.functional.scaled_dot_product_attention(query, key, value) - - return (efficient, flash, efficient_scale, flash_scale), replacement \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py deleted file mode 100644 index 9aac192316..0000000000 --- a/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py +++ /dev/null @@ -1,771 +0,0 @@ -from __future__ import annotations - -import logging -from contextlib import nullcontext -from tempfile import tempdir -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import tensorrt as trt -import torch -import torch_tensorrt -from torch.nn import Module -from torch_tensorrt._Device import Device -from torch_tensorrt._enums import Platform, dtype -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM -from torch_tensorrt.logging import TRT_LOGGER -from torch_tensorrt.runtime._utils import ( - _is_switch_required, - _select_rt_device, - multi_gpu_device_check, -) - -logger = logging.getLogger(__name__) - - -class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] - def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: - trt.IOutputAllocator.__init__(self) - self.buffers: Dict[str, torch.Tensor] = {} - self.shapes: Dict[str, Tuple[int, ...]] = {} - self.dtypes: Dict[str, torch.dtype] = output_dtypes - - def reallocate_output_async( - self, - tensor_name: str, - memory: int, - size: int, - alignment: int, - stream: torch.cuda.Stream, - ) -> Any: - shape = (size,) - if tensor_name not in self.buffers: - self.buffers[tensor_name] = torch.empty( - shape, - dtype=self.dtypes[tensor_name], - device=torch.cuda.current_device(), - ) - else: - if self.buffers[tensor_name].shape != shape: - self.buffers[tensor_name] = torch.empty( - shape, - dtype=self.dtypes[tensor_name], - device=torch.cuda.current_device(), - ) - return self.buffers[tensor_name].data_ptr() - - def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None: - self.shapes[tensor_name] = tuple(shape) - - -class TorchTRTRuntimeStates: - def __init__(self, new_cudagraphs: bool): - # Indicates whether CUDAGraphs were enabled in the previous execute_engine - self.old_cudagraphs = new_cudagraphs - # Indicates whether pre-allocated output was enabled in the previous execute_engine - self.old_pre_allocated_outputs = False - # Indicates whether context has changed - self.context_changed = False - - def set_runtime_states( - self, - new_cudagraphs: bool, - new_pre_allocated_output: bool, - shape_changed: bool, - ) -> Tuple[bool, bool, bool]: - # Evaluates whether certain conditions are met to enable CUDA Graph recording 
or to use pre-allocated outputs - # based on the current and previous states, as well as input shape has changed - need_cudagraphs_record = False - can_use_pre_allocated_outputs = False - need_cudagraphs_reset = False - - # CUDA Graph recording is needed if CUDA graphs is enabled and: - # - CUDA graphs were previously disabled - # - or the shape has changed - # - or the execution context has changed (e.g., weight streaming) - if new_cudagraphs and ( - not self.old_cudagraphs or shape_changed or self.context_changed - ): - need_cudagraphs_record = True - - # Pre-allocated output can be used when previous and current state are true without shape change - if ( - self.old_pre_allocated_outputs - and new_pre_allocated_output - and (not shape_changed) - ): - can_use_pre_allocated_outputs = True - - if not new_cudagraphs or shape_changed or self.context_changed: - need_cudagraphs_reset = True - - self.old_cudagraphs = new_cudagraphs - self.old_pre_allocated_outputs = new_pre_allocated_output - # reset flag - self.context_changed = False - - return ( - need_cudagraphs_record, - can_use_pre_allocated_outputs, - need_cudagraphs_reset, - ) - - -class PythonTorchTensorRTModule(Module): # type: ignore[misc] - """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. - - This module is backed by the Torch-TensorRT runtime and is only compatible with - FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment. - """ - - def __init__( - self, - serialized_engine: Optional[bytes] = None, - input_binding_names: Optional[List[str]] = None, - output_binding_names: Optional[List[str]] = None, - *, - name: str = "", - settings: CompilationSettings = CompilationSettings(), - weight_name_map: Optional[dict[Any, Any]] = None, - requires_output_allocator: bool = False, - ): - """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs - a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine - - Arguments: - serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray - input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules - output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned - - Keyword Arguments: - name (str): Name for module - settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed - weight_name_map (dict): Mapping of engine weight name to state_dict weight name - requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) - - Example: - - .. 
code-block:: py - - trt_module = PythonTorchTensorRTModule( - engine_str, - input_binding_names=["x"], - output_binding_names=["output"], - name="my_module", - settings=CompilationSettings(device=torch.cuda.current_device) - ) - - """ - self.context: Any - super(PythonTorchTensorRTModule, self).__init__() - self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) - - # Run multi-gpu device check to validate engine instantiation - multi_gpu_device_check() - - self.name = name - self._input_buffers: Dict[str, List[torch.Tensor]] = {} - self._output_buffers: Dict[str, List[torch.Tensor]] = {} - self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - self._caller_stream: Optional[torch.cuda.Stream] = None - self._engine_stream: Optional[torch.cuda.Stream] = None - - # TODO: Make the below a Dictionary {shape: cudagraph} - self.shape_key_to_cudagraph: Dict[str, torch.cuda.CUDAGraph] = {} - - # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98 - # Unused currently - to be used by Dynamic Shape support implementation - self.memory_pool = None - - self.serialized_engine = serialized_engine - self.input_names = ( - input_binding_names if input_binding_names is not None else [] - ) - self.output_names = ( - output_binding_names if output_binding_names is not None else [] - ) - self.initialized = False - self.target_device_id = ( - settings.device.gpu_id - if settings.device is not None - else Device._current_device().gpu_id - ) - self.target_device_properties = torch.cuda.get_device_properties( - self.target_device_id - ) - self.profiling_enabled = settings.debug if settings.debug is not None else False - self.settings = settings - self.engine = None - self.weight_name_map = weight_name_map - self.target_platform = Platform.current_platform() - self.runtime_states = TorchTRTRuntimeStates( - torch_tensorrt.runtime.get_cudagraphs_mode() - ) - - self.cudagraphs_enabled = False - self.pre_allocated_outputs: List[torch.Tensor] = [] - self.use_pre_allocated_outputs = False - - self.requires_output_allocator = requires_output_allocator - self.output_allocator: Optional[DynamicOutputAllocator] = None - self.use_output_allocator_outputs = False - - if self.serialized_engine is not None and not self.settings.lazy_engine_init: - self.setup_engine() - - def get_streamable_device_memory_budget(self) -> Any: - return self.engine.streamable_weights_size - - def get_automatic_device_memory_budget(self) -> Any: - return self.engine.get_weight_streaming_automatic_budget() - - def get_device_memory_budget(self) -> Any: - return self.engine.weight_streaming_budget_v2 - - def set_device_memory_budget(self, budget_bytes: int) -> int: - # Recreating the context because weight streaming budget cannot be modified while there are active context. 
- if self.context is not None: - del self.context - budget_bytes = self._set_device_memory_budget(budget_bytes) - self.context = self.engine.create_execution_context() - self.runtime_states.context_changed = True - return budget_bytes - - def _set_device_memory_budget(self, budget_bytes: int) -> int: - # Disable weight streaming for invalid budget size - if budget_bytes < 0: - budget_bytes = self.get_streamable_device_memory_budget() - self.engine.weight_streaming_budget_v2 = budget_bytes - if self.engine.weight_streaming_budget_v2 != budget_bytes: - logger.error(f"Failed to set weight streaming budget to {budget_bytes}") - budget_bytes = self.engine.weight_streaming_budget_v2 - if self.get_streamable_device_memory_budget() == budget_bytes: - logger.warning("Weight streaming is disabled") - - return budget_bytes - - def set_default_device_memory_budget(self) -> int: - budget_bytes = self.get_automatic_device_memory_budget() - # Set automatic weight streaming budget as default when context is created - logger.debug(f"Weight streaming budget set to {budget_bytes}B") - return self._set_device_memory_budget(budget_bytes) - - def setup_engine(self) -> None: - assert ( - self.target_platform == Platform.current_platform() - ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" - - self.initialized = True - runtime = trt.Runtime(TRT_LOGGER) - self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) - if self.settings.enable_weight_streaming: - self.set_default_device_memory_budget() - self.context = self.engine.create_execution_context() - assert self.engine.num_io_tensors == ( - len(self.input_names) + len(self.output_names) - ) - - self.input_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(input_name)) - for input_name in self.input_names - ] - self.input_shapes = [ - self.engine.get_tensor_shape(input_name) for input_name in self.input_names - ] - self.output_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype) - for output_name in self.output_names - ] - self.output_shapes = [ - self.engine.get_tensor_shape(output_name) - for output_name in self.output_names - ] - - if self.requires_output_allocator: - self.create_output_allocator() - - if torch_tensorrt.runtime.get_cudagraphs_mode(): - self.cudagraph = torch.cuda.CUDAGraph() - - def _check_initialized(self) -> None: - if not self.initialized: - raise RuntimeError("PythonTorchTensorRTModule is not initialized.") - - def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: - state_dict[prefix + "engine"] = self.serialized_engine - state_dict[prefix + "input_names"] = self.input_names - state_dict[prefix + "output_names"] = self.output_names - state_dict[prefix + "platform"] = self.target_platform - - def _load_from_state_dict( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Any, - strict: Any, - missing_keys: Any, - unexpected_keys: Any, - error_msgs: Any, - ) -> None: - self.serialized_engine = state_dict[prefix + "engine"] - self.input_names = state_dict[prefix + "input_names"] - self.output_names = state_dict[prefix + "output_names"] - self.target_platform = state_dict[prefix + "platform"] - - # Run multi-gpu device check to validate engine instantiation - multi_gpu_device_check() - self.setup_engine() - - def __getstate__(self) -> Dict[str, Any]: - state = self.__dict__.copy() - state.pop("engine", None) - state.pop("context", None) - return state - - def 
__setstate__(self, state: Dict[str, Any]) -> None: - self.__dict__.update(state) - self.setup_engine() - - def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - result.__setstate__(self.__getstate__()) - return result - - def _reset_captured_graph(self, inputs_shape_key: str = None) -> None: - if inputs_shape_key in self.shape_key_to_cudagraph: - self.shape_key_to_cudagraph[inputs_shape_key].reset() - self.shape_key_to_cudagraph.pop(inputs_shape_key) - - def __del__(self) -> None: - self._reset_captured_graph() - - def setup_input_tensors( - self, - contiguous_inputs: List[torch.Tensor], - cudagraphs_enabled: bool, - need_cudagraphs_record: bool, - inputs_shape_key: str = None, - ) -> None: - for i, input_name in enumerate(self.input_names): - if not contiguous_inputs[i].is_cuda: - logger.warning( - f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " - "This tensor is being moved by the runtime but for performance considerations, " - "ensure your inputs are all on GPU and open an issue here " - "(https://github.com/pytorch/TensorRT/issues) if this warning persists." - ) - contiguous_inputs = ( - contiguous_inputs[:i] - + [contiguous_inputs[i].cuda()] - + contiguous_inputs[i + 1 :] - ) - - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - - is_shape_tensor_input = self.engine.is_shape_inference_io(input_name) - if need_cudagraphs_record: - # If cudagraphs is enabled, this memory is reserved for future cudagraph runs - # Clone is required to avoid re-using user-provided GPU memory - if is_shape_tensor_input: - self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone() - else: - self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone() - - # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers - # as per TensorRT requirements - if is_shape_tensor_input: - # Shape tensor inputs are casted to int64 explicitly - # Currently Torch CPU pointers are not working; numpy pointers are used instead - # to refer to underlying memory - inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64) - inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() - # if cudagraphs_enabled: - # self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu) - # self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data) - # else: - self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data) - else: - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) - ) - if cudagraphs_enabled: - self._input_buffers[inputs_shape_key][i].copy_(contiguous_inputs[i]) - self.context.set_tensor_address( - input_name, self._input_buffers[inputs_shape_key][i].data_ptr() - ) - else: - self.context.set_tensor_address( - input_name, contiguous_inputs[i].data_ptr() - ) - - def create_output_tensors(self) -> List[torch.Tensor]: - # create output tensors - outputs: List[torch.Tensor] = [] - - for o, _ in enumerate(self.output_names): - output = torch.empty( - size=self.output_shapes[o], - dtype=self.output_dtypes[o], - device=torch.cuda.current_device(), - ) - outputs.append(output) - return outputs - - def set_pre_allocated_outputs(self, enable: bool) -> None: - self.use_pre_allocated_outputs = enable - - def 
set_use_output_allocator(self, enable: bool) -> None: - self.use_output_allocator_outputs = enable - - def create_output_allocator(self) -> None: - if self.output_allocator is None: - output_dtypes_dict = {} - for o, output_name in enumerate(self.output_names): - output_dtypes_dict[output_name] = self.output_dtypes[o] - self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) - - def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - - def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: - # print(f"**************** first key cache shape: {inputs[1].shape}") - shape_changed, inputs_shape_key = self.validate_input_shapes(inputs) - ( - need_cudagraphs_record, - can_use_pre_allocated_outputs, - need_cudagraphs_reset, - ) = self.runtime_states.set_runtime_states( - self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed - ) - - if need_cudagraphs_reset: - self._reset_captured_graph(inputs_shape_key) - - if need_cudagraphs_record: - self._input_buffers[inputs_shape_key] = [None] * len(self.input_names) - self._output_buffers[inputs_shape_key] = [None] * len(self.output_names) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) - if self.profiling_enabled - else nullcontext() - ): - assert len(contiguous_inputs) == len( - self.input_names - ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." - - self.setup_input_tensors( - contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key - ) - - if shape_changed: - # Check if input shapes can be inferred. - uninferred_input_names = self.context.infer_shapes() - if uninferred_input_names: - logger.warning( - f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ - This could happen if the input tensor addresses/shapes haven't been configured correctly" - ) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" - ) - if self.profiling_enabled - else nullcontext() - ): - if can_use_pre_allocated_outputs: - outputs = self.pre_allocated_outputs - else: - self.output_shapes = [ - tuple(self.context.get_tensor_shape(output_name)) - for output_name in self.output_names - ] - if DYNAMIC_DIM in self.output_shapes: - raise ValueError( - "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
- ) - outputs = self.create_output_tensors() - - for o, output_name in enumerate(self.output_names): - if need_cudagraphs_record: - self._output_buffers[inputs_shape_key][o] = outputs[o].clone() - - if self.cudagraphs_enabled: - self.context.set_tensor_address( - output_name, self._output_buffers[inputs_shape_key][o].data_ptr() - ) - else: - self.context.set_tensor_address( - output_name, outputs[o].data_ptr() - ) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:TensorRTRuntime" - ) - if self.profiling_enabled - else nullcontext() - ): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - - self._engine_stream.wait_stream(self._caller_stream) - - with torch.cuda.stream(self._engine_stream): - if self.cudagraphs_enabled: - if need_cudagraphs_record: - - self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph() - - if self.profiling_enabled: - self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode() - - with torch.cuda.graph( - self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream - ): - self.context.execute_async_v3( - self._engine_stream.cuda_stream - ) - - if self.profiling_enabled: - import tempfile - - with tempfile.TemporaryDirectory() as tmpdir: - self.shape_key_to_cudagraph[inputs_shape_key].debug_dump( - f"{tempdir}/{self.name}_cudagraph.dot" - ) - - self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore - - else: - self.context.execute_async_v3(self._engine_stream.cuda_stream) - - self._caller_stream.wait_stream(self._engine_stream) - - if self.use_pre_allocated_outputs: - self.pre_allocated_outputs = self.create_output_tensors() - - if self.cudagraphs_enabled: - for idx, o in enumerate(outputs): - o.copy_(self._output_buffers[inputs_shape_key][idx]) - - if len(outputs) == 1: - return outputs[0] - - return outputs - - def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: - assert ( - not torch_tensorrt.runtime.get_cudagraphs_mode() - ), "CUDA Graphs are not compatible with OutputAllocator." - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) - if self.profiling_enabled - else nullcontext() - ): - assert len(contiguous_inputs) == len( - self.input_names - ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." 
- - self.setup_input_tensors(contiguous_inputs, False, False) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:SetupOutputAllocator" - ) - if self.profiling_enabled - else nullcontext() - ): - self.create_output_allocator() - # need to set output allocator every run - for output_name in self.output_names: - if not self.context.set_output_allocator( - output_name, self.output_allocator - ): - raise RuntimeError( - f"Failed to set output allocator for {output_name}" - ) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:TensorRTRuntime" - ) - if self.profiling_enabled - else nullcontext() - ): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - - self._engine_stream.wait_stream(self._caller_stream) - - with torch.cuda.stream(self._engine_stream): - self.context.execute_async_v3( - self._engine_stream.cuda_stream - ) # The OutputAllocator is called by execute_async_v3() - - self._caller_stream.wait_stream(self._engine_stream) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" - ) - if self.profiling_enabled - else nullcontext() - ): - outputs = [] - assert self.output_allocator is not None - for o, output_name in enumerate(self.output_names): - shape = self.output_allocator.shapes.get(output_name, None) - dtype = self.output_dtypes[o] - output = ( - self.output_allocator.buffers.get(output_name, None) - .clone() - .detach() - ) - prod = int(torch.prod(torch.tensor(shape))) - # When using the OutputAllocator, the allocated buffer might be larger than the size of the output, - # so we need to reshape the buffer to the output shape - output = output.reshape(-1).view(dtype)[:prod].reshape(shape) - outputs.append(output) - - if len(outputs) == 1: - return outputs[0] - - return outputs - - self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() - - # Run forward function - contiguous_inputs: List[torch.Tensor] = [ - (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) - for i in inputs - ] - with ( - torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") - if self.profiling_enabled - else nullcontext() - ): - self._check_initialized() - - # If in safe mode, check at each iteration for whether a switch is required - if ( - torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE - ): - curr_device_id = torch.cuda.current_device() - curr_device_properties = torch.cuda.get_device_properties( - curr_device_id - ) - logger.debug(f"Current Device: cuda:{curr_device_id}") - - # If a switch is required, move all inputs to new device and set as active device - if _is_switch_required( - curr_device_id, - self.target_device_id, - curr_device_properties, - self.target_device_properties, - ): - device_id, _ = _select_rt_device( - curr_device_id, - self.target_device_id, - self.target_device_properties, - ) - - # Update current device - device = torch.device(device_id) - torch.cuda.set_device(device_id) - - contiguous_inputs = [ - tensor.to(device) for tensor in contiguous_inputs - ] - logger.warning(f"Moved all input Tensors to cuda:{device_id}") - - if self.requires_output_allocator: # engine requires OA - if self.cudagraphs_enabled: - raise RuntimeError( - "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. 
Please disable CUDA Graphs." - ) - logger.debug("Using the dynamic allocator runtime mode.") - return run_output_allocator() - else: - if self.use_output_allocator_outputs: # users call OA context manager - if self.cudagraphs_enabled: - raise RuntimeError( - "Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two." - ) - logger.debug("Using the dynamic allocator runtime mode.") - return run_output_allocator() - else: - logger.debug( - f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." - ) - return run_standard_execution() - - def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: - """ - Enable TensorRT profiling. After calling this function, TensorRT will report - time spent on each layer in stdout for each forward run. - """ - self._check_initialized() - - if not self.context.profiler: - self.context.profiler = trt.Profiler() if profiler is None else profiler - - self.profiling_enabled = True - - def disable_profiling(self) -> None: - """ - Disable TensorRT profiling. - """ - self._check_initialized() - torch.cuda.synchronize() - del self.context - self.context = self.engine.create_execution_context() - self.profiling_enabled = False - - def get_layer_info(self) -> str: - """ - Get layer info of the engine. Only support for TRT > 8.2. - """ - inspector = self.engine.create_engine_inspector() - engine_json: str = inspector.get_engine_information( - trt.LayerInformationFormat.JSON - ) - return engine_json - - def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: - """ - Validates the input shapes of the forward function has changed - """ - # Representation of input shapes to a given model - # Shapes are concatenated as so: - # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - tensor_inputs = [ - t if isinstance(t, torch.Tensor) else torch.tensor(t) - for t in inputs - ] - new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) - - # If the new shape key differs from the existing one, - # invalidate the old shape key and remove the CUDAGraph - if new_shape_key not in self.shape_key_to_cudagraph: - logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.") - # self.shape_key = new_shape_key - return True, new_shape_key - - return False, new_shape_key diff --git a/tools/llm/dynamic_cache.py b/tools/llm/dynamic_cache.py deleted file mode 100644 index b45ebb6d43..0000000000 --- a/tools/llm/dynamic_cache.py +++ /dev/null @@ -1,215 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -import torch.utils._pytree as pytree -from cache_utils import ( - add_graph_input, - create_random_output_tensors, - get_kv_nodes, - is_op, -) -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( - _aten_lowering_pass, -) -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -from torch_tensorrt.dynamo.utils import extract_var_range_info - -logger = logging.getLogger(__name__) - - -def add_kv_as_outputs(gm): - """ - Modifies the graph to add query, key, and value tensors as outputs. - - This function identifies all scaled dot-product attention (SDPA) operations - in the graph, creates copies of their query, key, and value inputs, and adds - these copies to the graph's outputs. 
This allows for accessing these tensors - externally, which is useful for operations like key-value caching. - - Args: - graph: The torch.fx.Graph to modify - - Returns: - None. The graph is modified in-place. - """ - # list of MHA kernels we would want to detect and replace - mha_ops = { - torch._C._nn.scaled_dot_product_attention, - } - - # Find all SDPA nodes in the graph - mha_nodes = [] - for node in gm.graph.nodes: - if is_op(node, mha_ops): - mha_nodes.append(node) - - # Iterate through each MHA node to extract shape information - for mha_node in mha_nodes: - if "val" in mha_node.meta and len(mha_node.args) >= 3: - # Get the input nodes (query, key, value) - q_node, k_node, v_node = mha_node.args[:3] - - # Add the copy nodes as outputs to the graph - output_node = next(node for node in gm.graph.nodes if node.op == "output") - - # Get the current output args (typically a tuple) - current_outputs = output_node.args[0] - - # If the current output is a tuple, extend it with our new outputs - if isinstance(current_outputs, tuple): - new_outputs = current_outputs + ((k_node, v_node),) - else: - # If there's only one output or it's not a tuple, create a new tuple - new_outputs = (current_outputs, (k_node, v_node)) - - gm.graph.output(new_outputs) - gm.graph.erase_node(output_node) - - return new_outputs - - -def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): - """ - Add key-value tensors and index parameters as inputs to the graph. - - Args: - gm: The GraphModule to modify - fixed_kv: Boolean indicating whether to use static tensors for KV cache - - Returns: - A tuple containing: - - List of (k_input, v_input) node pairs for each SDPA operation - - start_idx input node for slicing operations - - end_idx input node for slicing operations - """ - - def get_static_tensor(tensor: torch.Tensor): - key_shape = [] - for dim in tensor.shape: - if isinstance(dim, torch.SymInt): - min_max_opt = extract_var_range_info(dim) - key_shape.append(min_max_opt["max"]) - else: - key_shape.append(dim) - - static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) - return static_tensor - - keys_values = get_kv_nodes(gm) - - kv_inputs = [] - for idx, key_value in enumerate(keys_values): - k_val = key_value[0].meta["val"] - v_val = key_value[1].meta["val"] - if fixed_kv: - k_val = get_static_tensor(k_val) - v_val = get_static_tensor(v_val) - - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) - kv_inputs.append((k_input, v_input)) - - # Add is_generate as input - is_generate_input = add_graph_input(gm, "is_generate", True) - is_generate_input.meta["val"] = torch.tensor(True) - - return kv_inputs, is_generate_input - - -def insert_torch_cond_before_sdpa( - gm, - incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], - is_generate_input: torch.Tensor, -): - """ - Insert a torch.cond operation before each scaled_dot_product_attention operation. 
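The deleted pass toggles between the prefill and generate phases by concatenating the cached keys/values with the incoming ones along the sequence dimension and selecting the result with a higher-order cond node. A minimal eager-mode sketch of that branch, with toy shapes assumed only for illustration:

import torch

# toy (batch, heads, seq, head_dim) shapes, assumed for illustration
k_cache = torch.randn(1, 4, 8, 64)  # keys accumulated from earlier steps
k_new = torch.randn(1, 4, 1, 64)    # keys produced for the current token
is_generate = True

# generate phase: append the new keys to the cache along the sequence dim (dim 2);
# prefill phase: use the fresh keys as-is
k = torch.cat([k_cache, k_new], dim=2) if is_generate else k_new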
- - Args: - gm: The FX GraphModule to modify - - Returns: - The modified GraphModule - """ - # Find all nodes with scaled_dot_product_attention - sdpa_nodes = [] - for node in gm.graph.nodes: - if ( - node.op == "call_function" - and node.target == torch._C._nn.scaled_dot_product_attention - ): - sdpa_nodes.append(node) - - # For each SDPA node, insert a torch.cond operation before it - for idx, sdpa_node in enumerate(sdpa_nodes): - - with gm.graph.inserting_before(sdpa_node): - # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) - q_node, k_node, v_node = sdpa_node.args[:3] - incoming_key, incoming_value = incoming_keys_values[idx] - # Create nodes for concatenating k with incoming_key and v with incoming_value - concatenated_k_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=( - [incoming_key, k_node], - 2, - ), # Concatenate along sequence length dimension - kwargs={}, - ) - concatenated_v_node = gm.graph.create_node( - "call_function", - torch.ops.aten.cat.default, - args=( - [incoming_value, v_node], - 2, - ), # Concatenate along sequence length dimension - kwargs={}, - ) - - # Create the torch.cond node - cond_k_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(is_generate_input, concatenated_k_node, k_node), - ) - - cond_v_node = gm.graph.create_node( - "call_function", - torch.ops.higher_order.cond, - args=(is_generate_input, concatenated_v_node, v_node), - ) - - sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] - - return gm - - -@_aten_lowering_pass -def insert_dynamic_kv_cache( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Insert FlashInfer MHA + KV cache ops in the graph""" - """Perform insertion of kv-caches and attention kernel.""" - - # Add static key and value as inputs to the graph - kv_inputs, is_generate_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) - - # Call the function to add KV as outputs - logits_keys_values = add_kv_as_outputs(gm) - - # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases - gm = insert_torch_cond_before_sdpa(gm, kv_inputs, is_generate_input) - - gm = clean_up_graph_after_modifications(gm) - - new_output_tensors = create_random_output_tensors(logits_keys_values) - new_out_spec = pytree.tree_flatten(new_output_tensors)[1] - gm._out_spec = new_out_spec - - logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) - return gm diff --git a/tools/llm/llm_pyt_benchmark.py b/tools/llm/llm_pyt_benchmark.py deleted file mode 100644 index f3d68a951a..0000000000 --- a/tools/llm/llm_pyt_benchmark.py +++ /dev/null @@ -1,83 +0,0 @@ -import timeit - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -USE_CACHE = True -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" -# MODEL_NAME = "Qwen/Qwen3-0.6B" -MAX_NEW_TOKENS = 128 - - -def main(): - # Initialize model and tokenizer - print("Loading model and tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, torch_dtype=torch.float16, use_cache=USE_CACHE, device_map="auto" - ) - # model.generation_config.cache_implementation = "static" - # model.forward = torch.compile(model.forward) - - # Prepare input prompt - word = "What" - # Tokenize the word - word_ids = tokenizer(word, return_tensors="pt").input_ids[ - 0 - ] # Get the first (and only) sequence - # Repeat the token 2048 
times - input_ids = ( - word_ids.repeat(1024).unsqueeze(0).to(model.device) - ) # Add batch dimension and move to device - print(f"Input tensor shape: {input_ids.shape}") - - # # Warm-up pass - print("Running warm-up pass...") - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE, - ) - - # Benchmark loop - print("Running benchmark...") - num_iterations = 10 - total_time = 0 - timings = [] - - for i in range(num_iterations): - start_time = timeit.default_timer() - output_ids = model.generate( - input_ids, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - use_cache=USE_CACHE, - ) - end_time = timeit.default_timer() - generation_time = end_time - start_time - total_time += generation_time - timings.append(generation_time) - - # Decode and print first iteration output - # if i == 0: - # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) - # print("\nFirst generation output:") - # print(output_text) - - # Calculate and print statistics - average_time = total_time / num_iterations - print(f"\nPerformance Statistics:") - print( - f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds" - ) - print(f"Average tokens per second: {100/average_time:.2f}") - print("\nIndividual timings (ms):") - for i, t in enumerate(timings): - print(f"Iteration {i+1}: {t*1000:.2f}") - - -if __name__ == "__main__": - main() diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py deleted file mode 100644 index 470b0e6d99..0000000000 --- a/tools/llm/run_vlm.py +++ /dev/null @@ -1,333 +0,0 @@ -""" -.. _torch_export_gpt2: - -Compiling GPT2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. -""" - -import argparse -import copy -import os -import sys -import timeit -from contextlib import nullcontext - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import ( - export_llm, - generate, - generate_with_kv_cache, - recordStats, - time_generate, -) - -# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * - -DEVICE = torch.device("cuda:0") - - -def get_model(args): - with torch.no_grad(): - # Supported list of models: - # - meta-llama/Llama-3.2-1B-Instruct - # - meta-llama/Llama-3.2-3B-Instruct - # - meta-llama/Llama-3.1-8B-Instruct - # - Qwen/Qwen2.5-1.5B-Instruct - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=2, - ) - .eval() - .cuda() - ) - if args.precision == "FP16": - model = model.to(torch.float16) - elif args.precision == "BF16": - model = model.to(torch.bfloat16) - else: - model = model.to(torch.float32) - - return model - - -def compile_torchtrt(model, input_ids, args): - max_seq_len = input_ids.shape[1] + args.num_tokens - ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[input_ids, position_ids], - enabled_precisions=enabled_precisions, - # truncate_double=True, - use_explicit_typing=use_explicit_typing, - use_fp32_acc=use_fp32_acc, - device=DEVICE, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - offload_module_to_cpu=True, - min_block_size=args.min_block_size, - ) - - return trt_model - - -def print_outputs(backend_name, gen_tokens, tokenizer): - print(f"========= {backend_name} =========") - print( - f"{backend_name} model generated text: ", - tokenizer.decode(gen_tokens[0], skip_special_tokens=True), - ) - print("===================================") - - -def measure_perf(trt_model, input_signature, backend_name): - # Measure average time for 10 iterations - import timeit - - import numpy as np - - total_time = 0 - iterations = 10 - - print("Running warmup iteration...") - # Warmup run - _ = trt_model(*input_signature) - torch.cuda.synchronize() - - print(f"Measuring performance over {iterations} iterations...") - for i in range(iterations): - start_time = timeit.default_timer() - _ = trt_model(*input_signature) - torch.cuda.synchronize() - end_time = timeit.default_timer() - iter_time = end_time - start_time - total_time += iter_time - # print(f"Iteration {i+1}: {iter_time:.4f} seconds") - - avg_time = total_time / iterations - print( - f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" - ) - print( - f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" - ) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run inference on a model with random input values" - ) - arg_parser.add_argument( - "--model", - type=str, - default="meta-llama/Llama-3.2-1B-Instruct", - help="Name of LLM model", - ) - arg_parser.add_argument( - "--tokenizer", - type=str, - default="", - help="Name of LLM model tokenizer", - ) - arg_parser.add_argument( - "--prompt", type=str, default="What is parallel programming ?", help="Prompt" - ) - arg_parser.add_argument( - "--precision", - 
type=str, - default="FP16", - help="Precision to use in the model. Options: FP16, BF16, FP32", - ) - arg_parser.add_argument( - "--iterations", type=int, default=5, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--min_block_size", type=int, default=1, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--num_tokens", - type=int, - default=128, - help="no. of output tokens to be generated", - ) - arg_parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size used for benchmarking" - ) - arg_parser.add_argument( - "--isl", - type=int, - default=2048, - help="Input sequence length used for benchmarking", - ) - arg_parser.add_argument( - "--enable_pytorch_run", - action="store_true", - help="Enable pytorch run (default: False)", - ) - arg_parser.add_argument( - "--cache", - type=str, - default="", - help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", - ) - arg_parser.add_argument( - "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" - ) - arg_parser.add_argument( - "--debug", action="store_true", help="Enable debug (default: False)" - ) - arg_parser.add_argument( - "--benchmark", action="store_true", help="Enable benchmark (default: False)" - ) - - args = arg_parser.parse_args() - with torch.inference_mode(): - model = get_model(args) - - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) - - # Prepare input for benchmarking or evaluation - if args.benchmark: - input_ids = torch.randint( - 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 - ).to(model.device) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - else: - model_inputs = tokenizer(args.prompt, return_tensors="pt") - input_ids = model_inputs["input_ids"].to(DEVICE) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens - # Pyt - pyt_gen_tokens = None - pyt_timings = None - pyt_stats = None - if args.enable_pytorch_run: - pyt_gen_tokens = generate( - model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id - ) - if args.benchmark: - pyt_timings = time_generate( - generate, - model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - pyt_stats = recordStats( - "PyTorch", - pyt_timings, - args.precision, - batch_size=args.batch_size, - compile_time_s=None, - ) - - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 - elif args.cache == "dynamic": - import dynamic_cache - - trt_model = compile_torchtrt(model, input_ids, args) - - if ( - args.cache == "static_v1" - or args.cache == "static_v2" - or args.cache == "dynamic" - ): - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
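When --cudagraph is passed, the script enables the CUDA graphs runtime mode before the decode loop so that both the prefill and generate paths are captured (see the set_cudagraphs_mode call just below). A minimal sketch of that toggle, assuming trt_model, input_ids, MAX_OUTPUT_SEQ_LENGTH, tokenizer and the generate_with_kv_cache helper are already set up as in this script:

import torch_tensorrt

# trt_model, input_ids, MAX_OUTPUT_SEQ_LENGTH, tokenizer and generate_with_kv_cache
# come from the surrounding script; this only shows the mode toggle
torch_tensorrt.runtime.set_cudagraphs_mode(True)
try:
    tokens = generate_with_kv_cache(
        trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
    )
finally:
    torch_tensorrt.runtime.set_cudagraphs_mode(False)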
- # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) - - trt_gen_tokens = generate_with_kv_cache( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - - if args.benchmark: - trt_timings = time_generate( - generate_with_kv_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - else: - trt_gen_tokens = generate( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - - if args.benchmark: - trt_stats = recordStats( - "TensorRT", - trt_timings, - args.precision, - batch_size=args.batch_size, - compile_time_s=None, - ) - - if not args.benchmark: - if args.enable_pytorch_run: - print_outputs("PyTorch", pyt_gen_tokens, tokenizer) - - print_outputs("TensorRT", trt_gen_tokens, tokenizer) - - if args.enable_pytorch_run: - print( - f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" - ) - - if args.benchmark: - if args.enable_pytorch_run: - print("=========PyTorch PERFORMANCE============ \n") - print(pyt_stats) - print("===================== \n") - print("=========TensorRT PERFORMANCE============ \n") - print(trt_stats) diff --git a/tools/llm/test_gemma.py b/tools/llm/test_gemma.py deleted file mode 100644 index a5dc420ade..0000000000 --- a/tools/llm/test_gemma.py +++ /dev/null @@ -1,389 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import argparse -import os -import sys -from contextlib import nullcontext - -import torch.nn as nn -import torch_tensorrt -from torch.testing._internal.common_utils import TestCase, run_tests -from transformers import AutoModelForCausalLM -from transformers.models.gemma3.configuration_gemma3 import Gemma3Config -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3Attention, - Gemma3DecoderLayer, -) - -# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * - -ATOL = 1e-5 -RTOL = 1e-5 - - -gemma3_model_name = "google/gemma-3-1b-it" -gemma3_model = ( - AutoModelForCausalLM.from_pretrained( - gemma3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ) - .eval() - .cuda() -) -GEMMA3_CONFIG = gemma3_model.config - - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print( - f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" - ) - - -def test_gemma3_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = gemma3_model.model.layers[0].self_attn.to(DTYPE) - - # gemma3 - hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda() - position_embeddings = ( - torch.randn((1, 5, 256), dtype=DTYPE).cuda(), - torch.randn((1, 5, 256), dtype=DTYPE).cuda(), - ) - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export( - model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug, - ) - trt_output = trt_model(hidden_states, position_embeddings, None) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - - -def test_gemma3_attention_with_static_cache(args): - - import static_cache_v2 - - DTYPE = torch.float32 - model = gemma3_model.model.layers[0].self_attn.to(DTYPE) - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() - position_embeddings = ( - torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), - torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), - ) - key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model(hidden_states, position_embeddings, None) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export( - model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes - ) - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[ - hidden_states, - 
position_embeddings, - None, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ], - enabled_precisions={torch.float32}, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_python_runtime=True, - ) - - # Test Prefill - trt_output, _, key_cache, value_cache = trt_model( - hidden_states, - position_embeddings, - None, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ) - print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() - position_embeddings_curr = ( - torch.randn((1, 1, 256), dtype=DTYPE).cuda(), - torch.randn((1, 1, 256), dtype=DTYPE).cuda(), - ) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = ( - torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), - torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), - ) - - is_causal = False - out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) - out_trt, _, key_cache, value_cache = trt_model( - hidden_states_curr, - position_embeddings_curr, - None, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ) - out_pyt = out_no_cache[:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - - -def test_gemma3_decoder(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - model = gemma3_model.model.layers[0].to(DTYPE) - # model.self_attn.is_sliding = False - - # gemma3 - hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda() - position_embeddings_global = ( - torch.randn((1, 6, 256), dtype=DTYPE).cuda(), - torch.randn((1, 6, 256), dtype=DTYPE).cuda(), - ) - position_embeddings_local = ( - torch.randn((1, 6, 256), dtype=DTYPE).cuda(), - torch.randn((1, 6, 256), dtype=DTYPE).cuda(), - ) - - pyt_output = model( - hidden_states, position_embeddings_global, position_embeddings_local - ) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ( - {1: seq_len}, - ({1: seq_len}, {1: seq_len}), - ({1: seq_len}, {1: seq_len}), - ) - ep = torch.export.export( - model, - (hidden_states, position_embeddings_global, position_embeddings_local), - dynamic_shapes=dynamic_shapes, - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[ - hidden_states, - position_embeddings_global, - position_embeddings_local, - ], - enabled_precisions={torch.float32}, - debug=args.debug, - ) - trt_output = trt_model( - hidden_states, position_embeddings_global, position_embeddings_local - ) - - print( - f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" - ) - # breakpoint() - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - - -def test_gemma3_decoder_with_static_cache(args): - - class Gemma3DecoderLayerBlock(nn.Module): - def __init__(self, model): - super().__init__() - self.config = GEMMA3_CONFIG - self.decoder = Gemma3DecoderLayer(config=self.config, layer_idx=0) - self.model = model - - def forward(self, hidden_states, position_embeddings): - return 
self.model(hidden_states, position_embeddings=position_embeddings) - - DTYPE = torch.float32 - model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE)) - - import static_cache_v2 - - # Inputs - ISL = 2048 - NUM_TOKENS = 128 - OSL = ISL + NUM_TOKENS - hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() - position_embeddings_global = ( - torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), - torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), - ) - position_embeddings_local = ( - torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), - torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), - ) - key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) - start_idx = 0 - end_idx = ISL - is_causal = True - - pyt_output = model( - hidden_states, position_embeddings_global, position_embeddings_local - ) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export( - model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - arg_inputs=[ - hidden_states, - position_embeddings, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ], - enabled_precisions={torch.float32}, - disable_tf32=True, - debug=args.debug, - # offload_module_to_cpu=True, - use_python_runtime=True, - ) - - # Test Prefill - trt_output, key_cache, value_cache = trt_model( - hidden_states, - position_embeddings, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ) - print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") - - # Test Generate - for start_idx in range(2048, 2176): - end_idx = start_idx + 1 - hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() - position_embeddings_curr = ( - torch.randn((1, 1, 256), dtype=DTYPE).cuda(), - torch.randn((1, 1, 256), dtype=DTYPE).cuda(), - ) - # Concatenate the current hidden_states with the previous ones - hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) - position_embeddings_full = ( - torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), - torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), - ) - - is_causal = False - out_no_cache = model(hidden_states_full, position_embeddings_full) - - out_trt, key_cache, value_cache = trt_model( - hidden_states_curr, - position_embeddings_curr, - key_cache, - value_cache, - start_idx, - end_idx, - is_causal, - ) - out_pyt = out_no_cache[0][:, -1:, :] - print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") - hidden_states = hidden_states_full - position_embeddings = position_embeddings_full - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", action="store_true", help="Enable debug (default: False)" - ) - arg_parser.add_argument( - "--precision", - type=str, - default="FP16", - help="Precision to use in the model. 
Options: FP16, BF16, FP32", - ) - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_gemma3_attention(args) - # test_gemma3_attention_with_static_cache(args) - test_gemma3_decoder(args) - # test_gemma3_decoder_with_static_cache(args) diff --git a/tools/llm/test_qwen3.py b/tools/llm/test_qwen3.py deleted file mode 100644 index e46f17050a..0000000000 --- a/tools/llm/test_qwen3.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -import argparse -import os -import sys -from contextlib import nullcontext - -import torch.nn as nn -import torch_tensorrt -from torch.testing._internal.common_utils import TestCase, run_tests -from transformers import AutoModelForCausalLM -from transformers.models.qwen3.configuration_qwen3 import Qwen3Config -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer - -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * - -ATOL = 1e-5 -RTOL = 1e-5 - - -qwen3_model_name = "Qwen/Qwen3-0.6B" -qwen3_model = ( - AutoModelForCausalLM.from_pretrained( - qwen3_model_name, - use_cache=False, - attn_implementation="sdpa", - num_hidden_layers=1, - ) - .eval() - .cuda() -) -QWEN_CONFIG = qwen3_model.config - - -def print_diff(tensor1, tensor2, prefix=""): - """ - Print the diff between two tensors - """ - print( - f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" - ) - - -def test_qwen_attention(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False - else: - enabled_precisions = {torch.float32} - - model = qwen3_model.model.layers[0].self_attn.to(DTYPE) - # qwen2.5 - hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() - position_embeddings = ( - torch.randn((1, 5, 128), dtype=DTYPE).cuda(), - torch.randn((1, 5, 128), dtype=DTYPE).cuda(), - ) - - pyt_output = model(hidden_states, position_embeddings, None) - - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) - ep = torch.export.export( - model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[hidden_states, position_embeddings, None], - enabled_precisions=enabled_precisions, - disable_tf32=True, - use_fp32_acc=use_fp32_acc, - use_explicit_typing=use_explicit_typing, - debug=args.debug, - ) - trt_output = trt_model(hidden_states, position_embeddings, None) - - if isinstance(pyt_output, tuple): - print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - else: - print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") - assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) - - -def test_qwen3_decoder(args): - - class QwenDecoderLayerBlock(nn.Module): - def __init__(self, model): - 
super().__init__() - self.config = QWEN_CONFIG - self.model = model - - def forward(self, hidden_states, position_ids, position_embeddings): - return self.model( - hidden_states, - position_ids=position_ids, - position_embeddings=position_embeddings, - ) - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - model = QwenDecoderLayerBlock(qwen3_model.model.layers[0].to(DTYPE)) - # qwen3 - hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() - position_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() - position_embeddings = ( - torch.randn((1, 5, 128), dtype=DTYPE).cuda(), - torch.randn((1, 5, 128), dtype=DTYPE).cuda(), - ) - - pyt_output = model(hidden_states, position_ids, position_embeddings) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}, ({1: seq_len}, {1: seq_len})) - ep = torch.export.export( - model, - (hidden_states, position_ids, position_embeddings), - dynamic_shapes=dynamic_shapes, - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[hidden_states, position_ids, position_embeddings], - enabled_precisions={torch.float32}, - debug=args.debug, - ) - trt_output = trt_model(hidden_states, position_ids, position_embeddings) - - print( - f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" - ) - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - - -def test_qwen3_model(args): - - DTYPE = torch.float32 - if args.precision == "FP16": - DTYPE = torch.float16 - elif args.precision == "BF16": - DTYPE = torch.bfloat16 - - model = qwen3_model.model.to(DTYPE) - # qwen3 - input_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() - position_ids = ( - torch.arange(input_ids.shape[1], dtype=torch.int64).cuda().unsqueeze(0) - ) - - pyt_output = model(input_ids, position_ids) - seq_len = torch.export.Dim("seq_len", min=2, max=2176) - dynamic_shapes = ({1: seq_len}, {1: seq_len}) - ep = torch.export.export( - model, (input_ids, position_ids), dynamic_shapes=dynamic_shapes - ) - - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[input_ids, position_ids], - enabled_precisions={torch.float32}, - use_python_runtime=True, - disable_tf32=True, - debug=args.debug, - ) - # breakpoint() - trt_output = trt_model(input_ids, position_ids) - - print( - f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" - ) - print( - f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}" - ) - print( - f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[2] - trt_output[2]))}" - ) - assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run test cases for llama attention and decoder" - ) - arg_parser.add_argument( - "--debug", action="store_true", help="Enable debug (default: False)" - ) - arg_parser.add_argument( - "--precision", - type=str, - default="FP32", - help="Precision to use in the model. 
Options: FP16, BF16, FP32", - ) - args = arg_parser.parse_args() - with torch.inference_mode(): - # test_qwen_attention(args) - # test_qwen3_decoder(args) - test_qwen3_model(args) From 5eabf650bbf53847050b9b46516070d5ba9fface Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Jun 2025 00:21:12 +0000 Subject: [PATCH 25/30] chore: updates --- examples/dynamo/weight_streaming_example.py | 38 ++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py index e1076a9e75..601292ba95 100644 --- a/examples/dynamo/weight_streaming_example.py +++ b/examples/dynamo/weight_streaming_example.py @@ -32,7 +32,43 @@ import torch import torch_tensorrt from transformers import AutoModelForCausalLM -from utils import export_llm + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep def time_generate(model, inputs, output_seq_length, iterations=10): From 806616d85b695901ee917ad8305d7d1839d83385 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Jun 2025 06:13:00 +0000 Subject: [PATCH 26/30] chore: Add README.md --- tools/llm/README.md | 66 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tools/llm/README.md diff --git a/tools/llm/README.md b/tools/llm/README.md new file mode 100644 index 0000000000..3fd55bc060 --- /dev/null +++ b/tools/llm/README.md @@ -0,0 +1,66 @@ +# Optimizing LLMs in Torch-TensorRT + +This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions. + +### Key Features + +- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. +- **Precision Modes:** Supports FP16, BF16, and FP32. +- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. +- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. 
+- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT. + + +### Supported Models + +We have officially verified support for the following models: + +| Model Series | HF Model Card | Precision | KV Cache Supported ? | +|--------------|---------------|-----------|-------------------| +| GPT-2 | gpt2
gpt2-medium | FP16, FP32 | Yes | +| LLaMA 2 | meta-llama/Llama-2-7b-chat-hf | FP16, FP32 | Yes | +| LLaMA 3.1 | meta-llama/Llama-3.1-8B-Instruct | FP16, FP32 | Yes | +| LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes | +| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-3B-Instruct
Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes | + + +### Usage + +The main entry point is : `run_llm.py` + +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark +``` + +#### Key Arguments + +- `--model`: Name or path of the HuggingFace LLM. +- `--tokenizer`: (Optional) Tokenizer name; defaults to model. +- `--prompt`: Input prompt for generation. +- `--precision`: Precision mode (`FP16`, `FP32`). +- `--num_tokens`: Number of output tokens to generate. +- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). +- `--benchmark`: Enable benchmarking mode. +- `--enable_pytorch_run`: Also run and compare PyTorch baseline. + +### Caching Strategies + +- **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse. +- **No Cache:** Standard autoregressive decoding. + +Please read our tutorial on how static cache is implemented. + +## Extension + +This codebase can be extended to +- Add new models by specifying their HuggingFace name. +- Implement new cache strategies by adding FX graph passes. +- Customize SDPA conversion for new attention mechanisms. + +## Limitations +- We do not currently support sliding window attention (used in Gemma3 and Qwen 3 models) yet. + +## Requirements + +- Torch-TensorRT 2.8.0 +- Transformers v4.52.3 \ No newline at end of file From c2cbd5a6c97c7099f0a7f5f052374f5eade92899 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Jun 2025 11:41:24 -0700 Subject: [PATCH 27/30] chore: add docs --- docsrc/index.rst | 6 ++---- examples/dynamo/aot_plugin.py | 9 +++++++++ tools/llm/test_static_cache.py | 22 +++++++++++++++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docsrc/index.rst b/docsrc/index.rst index 67fbdc56f5..4d28d77640 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -140,11 +140,10 @@ Model Zoo * :ref:`torch_compile_resnet` * :ref:`torch_compile_transformer` * :ref:`torch_compile_stable_diffusion` +* :ref:`compile_hf_models` * :ref:`torch_compile_gpt2` * :ref:`torch_export_gpt2` -* :ref:`torch_export_llama2` * :ref:`torch_export_sam2` -* :ref:`torch_export_flux_dev` * :ref:`notebooks` .. toctree:: @@ -155,11 +154,10 @@ Model Zoo tutorials/_rendered_examples/dynamo/torch_compile_resnet_example tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion + tutorials/compile_hf_models tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion tutorials/_rendered_examples/dynamo/torch_compile_gpt2 - tutorials/_rendered_examples/dynamo/torch_export_gpt2 - tutorials/_rendered_examples/dynamo/torch_export_llama2 tutorials/_rendered_examples/dynamo/torch_export_sam2 tutorials/_rendered_examples/dynamo/torch_export_flux_dev tutorials/notebooks diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py index 7e8204c165..4aa49e4eca 100644 --- a/examples/dynamo/aot_plugin.py +++ b/examples/dynamo/aot_plugin.py @@ -1,3 +1,12 @@ +""" +.. _aot_plugin: + +AOT Plugin +========== + +This example demonstrates how to use an AOT plugin in Torch-TensorRT. 
+""" + import argparse from typing import Tuple, Union diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py index af72f93b67..c27c6b2284 100644 --- a/tools/llm/test_static_cache.py +++ b/tools/llm/test_static_cache.py @@ -58,12 +58,24 @@ class StaticCacheModel(nn.Module): def __init__(self): super().__init__() - # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): - # new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) - # new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) - # out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + def forward( + self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ): + new_key_cache = torch.cat( + (key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2 + ) + new_value_cache = torch.cat( + (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2 + ) + out = torch._C._nn.scaled_dot_product_attention( + q, + new_key_cache[:, :, :end_idx, :], + new_value_cache[:, :, :end_idx, :], + dropout_p=0.0, + is_causal=is_causal, + ) - # return out, new_key_cache, new_value_cache + return out, new_key_cache, new_value_cache def forward( self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True From 81114f8caa606d55dd83c97197e9bf330c181f2d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Jun 2025 13:31:47 -0700 Subject: [PATCH 28/30] chore: Add a tutorial --- docsrc/tutorials/compile_hf_models.rst | 218 +++++++++++++++++++++++++ tools/llm/run_llm.py | 1 - tools/llm/test_static_cache.py | 12 +- 3 files changed, 223 insertions(+), 8 deletions(-) create mode 100644 docsrc/tutorials/compile_hf_models.rst diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst new file mode 100644 index 0000000000..d173780040 --- /dev/null +++ b/docsrc/tutorials/compile_hf_models.rst @@ -0,0 +1,218 @@ +.. _compile_hf_models: + +Compiling LLM models from Huggingface +====================================== + +This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT which can greatly improve the performance of LLM inference. +The code is available in the `tools/llm `_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure the performance. + +.. note:: + This is an **experimental release** and APIs may change in future versions. + +.. note:: + The compilation scripts and tutorials for Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm `_ directory. + +Overview of tools/llm Directory +------------------------------- + +The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface: + +* **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking +* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization +* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass. 
+* **Testing Components**: Model-specific test files for validation +* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations + +Supported Models +---------------- +We have officially verified support for the following LLM families: + +.. list-table:: + :widths: 20 40 20 20 + :header-rows: 1 + + * - Model Series + - HuggingFace Model Card + - Precision + - KV Cache Support ? + * - GPT-2 + - gpt2 + - FP16, FP32 + - Yes + * - LLaMA 2 + - meta-llama/Llama-2-7b-chat-hf + - FP16, FP32 + - Yes + * - LLaMA 3.1 + - meta-llama/Llama-3.1-8B-Instruct + - FP16, FP32 + - Yes + * - LLaMA 3.2 + - | meta-llama/Llama-3.2-1B-Instruct + | meta-llama/Llama-3.2-3B-Instruct + - FP16, FP32 + - Yes + * - Qwen 2.5 + - | Qwen/Qwen2.5-0.5B-Instruct + | Qwen/Qwen2.5-1.5B-Instruct + | Qwen/Qwen2.5-4B-Instruct + | Qwen/Qwen2.5-7B-Instruct + - FP16, FP32 + - Yes + +Getting Started with run_llm.py +------------------------------- + +The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking. + +Basic Usage +^^^^^^^^^^^ + +.. code-block:: bash + + python tools/llm/run_llm.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --prompt "What is parallel programming?" \ + --precision FP16 \ + --num_tokens 128 \ + --cache static_v2 \ + --benchmark + +Key Arguments +^^^^^^^^^^^^^ + +* ``--model``: Name or path of the HuggingFace LLM +* ``--tokenizer``: (Optional) Tokenizer name; defaults to model name +* ``--prompt``: Input prompt for text generation +* ``--precision``: Precision mode (``FP16``, ``FP32``) +* ``--num_tokens``: Number of output tokens to generate +* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching) +* ``--benchmark``: Enable benchmarking mode for performance comparison +* ``--enable_pytorch_run``: Also run and compare PyTorch baseline + + +Other Usage Examples +^^^^^^^^^^^^^^^^^^^^ +.. code-block:: bash + + # Compare different models performance + python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run + + # Generate the outputs (disable benchmarking) by specifying the number of tokens to generate. Default = 128 + python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128 + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128 + + # Test different caching approaches + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1 + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2 + + # Compare FP16 vs FP32 performance + python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark + python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark + + +KV Caching in Torch-TensorRT +--------------------------------- + +We provide two versions of static KV caching: `static_cache_v1 `_ and `static_cache_v2 `_. +In both implementations, we add static KV cache tensors as model inputs/outputs without storing them as external memory. +The length of KV cache = input sequence length + output sequence length (specified by ``--num_tokens``). The number of heads and head dimension are determined by the model config. + +Static Cache v1 +^^^^^^^^^^^^^^^^ + +The ``static_cache_v1.py`` implements KV cache in the model graph as follows: + +.. 
code-block:: python + + class StaticCacheV1Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + # Concatenate new key/value pairs with existing cache + new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + + # Compute attention using the updated cache + attn_output = torch._C._nn.scaled_dot_product_attention( + q, + new_key_cache[:, :, :end_idx, :], + new_value_cache[:, :, :end_idx, :], + dropout_p=0.0, + is_causal=is_causal + ) + + return attn_output, new_key_cache, new_value_cache + +In the above code, we concatenate the new key/value pairs with the existing cache and update it. To compute the attention, we use the updated cache and gather the corresponding keys/values from the cache up until and including the current token index. +The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator ``@_aten_lowering_pass`` when we import the ``static_cache_v1.py`` module. + +.. note:: + The ``start_idx`` and ``end_idx`` are the start and end indices of the current token in the cache. For prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length. + For decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. The ``start_idx`` is incremented by 1 until the end of the sequence or we reach the maximum number of tokens to generate. + + +Static Cache v2 +^^^^^^^^^^^^^^^^ + +The ``static_cache_v2.py`` is similar to ``static_cache_v1.py`` but it uses less number of slice operations. It implements KV cache in the model graph as follows: + +.. code-block:: python + + class StaticCacheV2Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + attn_output = torch._C._nn.scaled_dot_product_attention( + q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal + ) + + return attn_output, new_key_cache, new_value_cache + +In the above code, we concatenate the existing key/value cache with current key/value of the token. We use this to directly compute the attention and update the key/value cache inserting the current key/value. +The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator ``@_aten_lowering_pass`` when we import the ``static_cache_v1.py`` module. +The definitons of ``start_idx`` and ``end_idx`` are the same as ``static_cache_v1.py``. + +After the model is compiled with static KV cache, the input signature of the model is changed. The new input signature is ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``. +The number of key/value cache tensors is equal to the number of attention heads in the model. We can use the ``generate_with_static_cache`` function to generate the outputs. 
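+
+To make the ``start_idx``/``end_idx`` bookkeeping concrete, below is a minimal, self-contained sketch of the prefill and decode phases. It is illustrative only: it reuses the v2 cache-update rule shown above as a plain function, runs in eager PyTorch (no TensorRT compilation), and the tensor sizes and the helper name ``sdpa_with_static_cache`` are assumptions, not part of the repository.
+
+.. code-block:: python
+
+    import torch
+
+    # Toy sizes -- assumptions for illustration only
+    B, H, D = 1, 4, 64            # batch, attention heads, head dimension
+    ISL, MAX_SEQ = 8, 16          # prefill length, total static cache length
+
+    def sdpa_with_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal):
+        # Same update rule as the StaticCacheV2Model sketch above, written as a plain function
+        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+        new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+        )
+        return out, new_key_cache, new_value_cache
+
+    key_cache = torch.zeros(B, H, MAX_SEQ, D)
+    value_cache = torch.zeros(B, H, MAX_SEQ, D)
+
+    # Prefill phase: start_idx=0, end_idx=ISL, causal attention over the whole prompt
+    q = torch.randn(B, H, ISL, D)
+    k = torch.randn(B, H, ISL, D)
+    v = torch.randn(B, H, ISL, D)
+    out, key_cache, value_cache = sdpa_with_static_cache(
+        q, k, v, key_cache, value_cache, 0, ISL, True
+    )
+
+    # Decode phase: one token per step, start_idx advances by 1 until the cache is full
+    for start_idx in range(ISL, MAX_SEQ):
+        end_idx = start_idx + 1
+        q = torch.randn(B, H, 1, D)
+        k = torch.randn(B, H, 1, D)
+        v = torch.randn(B, H, 1, D)
+        out, key_cache, value_cache = sdpa_with_static_cache(
+            q, k, v, key_cache, value_cache, start_idx, end_idx, False
+        )
+
+When the model is compiled with Torch-TensorRT and one of the static cache lowering passes, ``generate_with_static_cache`` performs this same bookkeeping, passing ``start_idx`` and ``end_idx`` as the trailing inputs of the compiled module.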
+ +Generating Outputs +------------------- +We use custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching. +There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching. + +The ``generate_with_static_cache`` function takes care of preparing the inputs to the model compiled with static KV cache. +The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ...., ``start_idx``, ``end_idx``. +We initialize the key/value cache tensors with zeros and for every token generated, the new key/value cache tensors are the outputs of the model. + +SDPA Converter (sdpa_converter.py) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* Converts scaled dot-product attention operation using TRT Python API. +* Supports causal and standard self-attention. + +SDPA Registration (register_sdpa.py) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``. +* Registers the SDPA converter which is used for converting ``torch.nn.functional.scaled_dot_product_attention`` operation. + + +Limitations and Known Issues +---------------------------- + +* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported +* Some model architectures (e.g. Phi-4) have issues with exporting the torch model. + +Requirements +^^^^^^^^^^^^ + +* Torch-TensorRT 2.8.0 or later +* Transformers v4.52.3 \ No newline at end of file diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index de8ab9d92a..7a98d2a2c0 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -44,7 +44,6 @@ def get_model(args): args.model, use_cache=False, attn_implementation="sdpa", - # num_hidden_layers=1 ) .eval() .cuda() diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py index c27c6b2284..603f84d3a6 100644 --- a/tools/llm/test_static_cache.py +++ b/tools/llm/test_static_cache.py @@ -67,7 +67,7 @@ def forward( new_value_cache = torch.cat( (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2 ) - out = torch._C._nn.scaled_dot_product_attention( + attn_output = torch._C._nn.scaled_dot_product_attention( q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], @@ -75,24 +75,22 @@ def forward( is_causal=is_causal, ) - return out, new_key_cache, new_value_cache + return attn_output, new_key_cache, new_value_cache def forward( self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True ): - concat_keys = torch.cat( - (key_cache[:, :, :start_idx, :], k), dim=2 - ) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ] + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) new_value_cache = torch.cat( (concat_values, value_cache[:, :, end_idx:, :]), dim=2 ) - out = torch._C._nn.scaled_dot_product_attention( + attn_output = torch._C._nn.scaled_dot_product_attention( q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal ) - return out, new_key_cache, new_value_cache + return attn_output, new_key_cache, new_value_cache def eager_sdpa( From 57da513bb83cbd4621cb4d24e766d7b7e729eda8 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 13 Jun 2025 13:54:18 -0700 Subject: [PATCH 29/30] chore: fix model name --- 
docsrc/tutorials/compile_hf_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst index d173780040..f6da87b145 100644 --- a/docsrc/tutorials/compile_hf_models.rst +++ b/docsrc/tutorials/compile_hf_models.rst @@ -55,7 +55,7 @@ We have officially verified support for the following LLM families: * - Qwen 2.5 - | Qwen/Qwen2.5-0.5B-Instruct | Qwen/Qwen2.5-1.5B-Instruct - | Qwen/Qwen2.5-4B-Instruct + | Qwen/Qwen2.5-3B-Instruct | Qwen/Qwen2.5-7B-Instruct - FP16, FP32 - Yes From deec75fcd2c16fca60b45eaef0c3b4b93cf6e31d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 17 Jun 2025 01:01:15 +0000 Subject: [PATCH 30/30] chore: address review comments --- .../dynamo/conversion/_TRTInterpreter.py | 3 +- tools/llm/cache_utils.py | 41 ++-------- tools/llm/run_llm.py | 75 +++++++++++-------- tools/llm/static_cache_v1.py | 12 +-- tools/llm/static_cache_v2.py | 12 +-- tools/llm/{ => torchtrt_ext}/register_sdpa.py | 3 +- .../llm/{ => torchtrt_ext}/sdpa_converter.py | 0 7 files changed, 67 insertions(+), 79 deletions(-) rename tools/llm/{ => torchtrt_ext}/register_sdpa.py (99%) rename tools/llm/{ => torchtrt_ext}/sdpa_converter.py (100%) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index bb1a77b4eb..f88c8fc53d 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -887,10 +887,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: else: return converter(self.ctx, target, args, kwargs, self._cur_node_name) - def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: + def get_attr(self, target: str, args: Any, kwargs: Any) -> torch.Tensor: with _disable_current_modes(), unset_fake_temporarily(): frozen_attr = self.fetch_attr(target) - if isinstance(frozen_attr, torch.nn.Parameter): constant_tensor = frozen_attr.data else: diff --git a/tools/llm/cache_utils.py b/tools/llm/cache_utils.py index 7089d9a220..d25e5bb40e 100644 --- a/tools/llm/cache_utils.py +++ b/tools/llm/cache_utils.py @@ -12,44 +12,19 @@ from torch.utils._pytree import _LEAF_SPEC -@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( - torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True -) -def cond_converter( - ctx: torch_tensorrt.dynamo.conversion.ConversionContext, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: str, -) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: +def get_kv_nodes(gm): """ - Converter for torch.ops.higher_order.cond operation to TensorRT. + Extract key and value nodes from scaled dot-product attention operations in the graph. - This function handles the conversion of PyTorch's conditional operation to TensorRT. - The conditional operation selects between two tensors based on a boolean predicate. + This function searches through the graph for scaled_dot_product_attention operations + and extracts the key and value tensor nodes from each operation's arguments. 
Args: - ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context - target (Target): The target operation to convert - args (Tuple[Argument, ...]): The arguments to the operation - kwargs (Dict[str, Argument]): The keyword arguments to the operation - name (str): The name to give to the TensorRT layer + gm: A torch.fx.GraphModule containing the computational graph Returns: - Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) - """ - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = args[0], args[1], args[2] - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - output = output_layer.get_output(0) - - return output - - -def get_kv_nodes(gm): - """ - Get the key and value nodes from the graph. + List[Tuple[Node, Node]]: A list of tuples, where each tuple contains + (key_node, value_node) from a scaled dot-product attention operation """ kv_nodes = [] for node in gm.graph.nodes: @@ -127,7 +102,7 @@ def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: return random_tensors -def add_graph_input( +def _add_graph_input( gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None ) -> Node: """Add a graph input to the given GraphModule and return the newly created node. diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7a98d2a2c0..5f42cb3e60 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -1,10 +1,10 @@ """ -.. _torch_export_gpt2: +.. _run_llm: -Compiling GPT2 using the dynamo backend +Running LLM inference with Torch-TensorRT ========================================================== -This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. +This script illustrates Torch-TensorRT workflow with dynamo backend on popular LLM models. """ import argparse @@ -18,12 +18,11 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch import torch_tensorrt -from register_sdpa import * +from torchtrt_ext import register_sdpa from transformers import AutoModelForCausalLM, AutoTokenizer from utils import ( export_llm, generate, - generate_with_dynamic_cache, generate_with_static_cache, recordStats, time_generate, @@ -33,17 +32,29 @@ def get_model(args): + """ + Load and configure the language model for inference. + + This function loads a pre-trained causal language model using the specified + model name and configures it with the appropriate precision and settings + for inference. + + Args: + args: Parsed command line arguments containing: + - model (str): Name or path of the model to load + - precision (str): Precision to use ("FP16", "BF16", or "FP32") + + Returns: + torch.nn.Module: The loaded and configured model ready for inference, + moved to CUDA device with the specified precision + """ with torch.no_grad(): - # Supported list of models: - # - meta-llama/Llama-3.2-1B-Instruct - # - meta-llama/Llama-3.2-3B-Instruct - # - meta-llama/Llama-3.1-8B-Instruct - # - Qwen/Qwen2.5-1.5B-Instruct model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, attn_implementation="sdpa", + num_hidden_layers=1, ) .eval() .cuda() @@ -59,6 +70,26 @@ def get_model(args): def compile_torchtrt(model, input_ids, args): + """ + Compile a PyTorch model to TensorRT using torch_tensorrt.dynamo.compile. + + This function exports the given model to a TorchScript representation and then + compiles it to TensorRT for optimized inference. 
The compilation process includes + precision-specific optimizations and various performance tuning parameters. + + Args: + model (torch.nn.Module): The PyTorch model to compile + input_ids (torch.Tensor): Input token IDs tensor used for model export + args: Parsed command line arguments containing: + - num_tokens (int): Number of tokens to generate (used for max sequence length) + - precision (str): Precision to use ("FP16", "BF16", or "FP32") + - debug (bool): Whether to enable debug logging + - min_block_size (int): Minimum block size for TensorRT compilation + + Returns: + torch_tensorrt.dynamo.TorchTensorRTModule: The compiled TensorRT model ready + for optimized inference + """ max_seq_len = input_ids.shape[1] + args.num_tokens ep = export_llm(model, input_ids, max_seq_len=max_seq_len) position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) @@ -191,7 +222,7 @@ def measure_perf(trt_model, input_signature, backend_name): "--cache", type=str, default="", - help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", + help="Type of KV cache to use. Options: static_v1, static_v2", ) arg_parser.add_argument( "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" @@ -249,12 +280,10 @@ def measure_perf(trt_model, input_signature, backend_name): if args.cache == "static_v1": # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 + from torchtrt_ext import static_cache_v1 if args.cache == "static_v2": # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 - elif args.cache == "dynamic": - import dynamic_cache + from torchtrt_ext import static_cache_v2 # Compile the model with Torch-TensorRT trt_model = compile_torchtrt(model, input_ids, args) @@ -281,22 +310,6 @@ def measure_perf(trt_model, input_signature, backend_name): tokenizer.eos_token_id, iterations=args.iterations, ) - elif args.cache == "dynamic": - trt_gen_tokens = generate_with_dynamic_cache( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate_with_dynamic_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) else: trt_gen_tokens = generate( trt_model, diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index a87495953d..b60396c08b 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -3,7 +3,7 @@ import torch import torch.utils._pytree as pytree -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes from torch.fx import Node from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -89,14 +89,14 @@ def get_static_tensor(tensor: torch.Tensor): k_val = get_static_tensor(k_val) v_val = get_static_tensor(v_val) - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + # Add new inputs using _add_graph_input + k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val) kv_inputs.append((k_input, v_input)) # Add start_idx and end_idx as inputs - 
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1)) # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py index ad386d39f2..4634b79a52 100644 --- a/tools/llm/static_cache_v2.py +++ b/tools/llm/static_cache_v2.py @@ -3,7 +3,7 @@ import torch import torch.utils._pytree as pytree -from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes from torch.fx import Node from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -89,14 +89,14 @@ def get_static_tensor(tensor: torch.Tensor): k_val = get_static_tensor(k_val) v_val = get_static_tensor(v_val) - # Add new inputs using add_graph_input - k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) - v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + # Add new inputs using _add_graph_input + k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val) kv_inputs.append((k_input, v_input)) # Add start_idx and end_idx as inputs - start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) - end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1)) # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] diff --git a/tools/llm/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py similarity index 99% rename from tools/llm/register_sdpa.py rename to tools/llm/torchtrt_ext/register_sdpa.py index c3c76e0f2d..90a00a5798 100644 --- a/tools/llm/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -4,7 +4,6 @@ from typing import Callable, Sequence, Tuple import torch -from sdpa_converter import * from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS @@ -15,6 +14,8 @@ clean_up_graph_after_modifications, ) +from .sdpa_converter import * + logger = logging.getLogger(__name__) # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention diff --git a/tools/llm/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py similarity index 100% rename from tools/llm/sdpa_converter.py rename to tools/llm/torchtrt_ext/sdpa_converter.py