-
Notifications
You must be signed in to change notification settings - Fork 363
feat: Refactor LLM model zoo and add KV cache support #3527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
from torch.utils._pytree import _LEAF_SPEC | ||
|
||
|
||
@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to upstream this eventually?
@@ -62,25 +62,15 @@ def scaled_dot_product_attention( | |||
) -> TRTTensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would put the torch_trt extensions in a subdirectory all together. like //tools/llm/torch_trt_ext/
@@ -0,0 +1,193 @@ | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we create a tests/py/tools/llm
?
) -> Node: | ||
"""Add a graph input to the given GraphModule and return the newly created node. | ||
|
||
NOTE: function does NOT do any graph canonicalization. This is left to the user! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you dont want users to hit this directly consider prefixing the method name with _
like _add_graph_input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly LGTM. just some questions
@@ -837,6 +859,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: | |||
continue | |||
submodule_node_dict[node.name] = node | |||
|
|||
preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should set spec for partitioned_module
's children?
from utils import export_llm | ||
|
||
|
||
def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not directly call tools/llm/utils.py?
Description
This PR redesigns our LLM model compilation, unifies it, fixes output mismatch and performance issues. This PR also implements KV caching using native TensorRT.
Fixes # (issue)
Type of change
Checklist: