-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
LFM2 #20797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
paulpak58
wants to merge
23
commits into
vllm-project:main
Choose a base branch
from
paulpak58:lfm2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,081
−92
Open
LFM2 #20797
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
3a50223
[cmake] ignore nvToolsExt for cuda-12.9
paulpak58 6fd86d2
[model_executor][models] LFM2 architecture
paulpak58 aaf7df1
[configs] use layer_types from huggingface hybrids >= 4.54.0.dev0
paulpak58 6c80caf
[model_runner][v1] ShortConvSpec for ShortConv layers; compatibility …
paulpak58 d17c95f
[configs] need to detect full_attention key in layer_types for transf…
paulpak58 1bc8835
[layers][conv] update ShortConv layer to be compatible with triton ca…
paulpak58 e550362
[transformers][ovis] tmp: AIMv2Config doesn't need to be registered o…
paulpak58 05af65a
[models][lfm2] LFM2->Lfm2 to match config
paulpak58 40d81e9
[merge] upstream @ 5bac61362b6718b90e708e7b212e7fcbe7d59fa3
paulpak58 7241660
[v1][cache] add support for conv cache shapes
paulpak58 3d3be6a
[v1][config] generalize HybridAttentionMambaModelConfig
paulpak58 b2447dd
[merge] upstream main @ 19c863068b2d70a452bde25318dbcf04f274ad19
paulpak58 46902dc
[layer][conv] update conv metadata in causal_conv1d
paulpak58 260e3fe
[misc] format + cleanup
paulpak58 1dff6e1
[layers][conv] fix minor discprencies in decode conv
paulpak58 30621b4
[merge] upstream @ a0f8a7964694a6077689b242b5eca95de392d4bb
paulpak58 9c3edab
[layers][conv] fix ordering of prefill/decode tokens in conv layer
paulpak58 63cd12b
[tests] register LFM2 in test models
paulpak58 9af96d9
[tests][hybrid] include LFM2 in V1 Hybrids + include unsupported V1 a…
paulpak58 7577e89
[docs] update supported_models + v1 guide
paulpak58 1ff0c89
[misc] fix pre-commit checks
paulpak58 b425c0d
[model_executor] remap mamba V1 utils to static_cache + cleanup
paulpak58 80a2f3a
[misc] minor: fix format
paulpak58 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
|
||
from typing import Any, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from vllm import envs | ||
from vllm.config import get_current_vllm_config | ||
from vllm.forward_context import get_forward_context | ||
from vllm.model_executor.custom_op import CustomOp | ||
from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_world_size | ||
from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||
MergedColumnParallelLinear, | ||
RowParallelLinear) | ||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, | ||
update_metadata) | ||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( | ||
causal_conv1d_fn, causal_conv1d_update) | ||
from vllm.attention.backends.abstract import AttentionMetadata | ||
from vllm.model_executor.models.conv_cache import ConvCacheParams | ||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata | ||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata | ||
|
||
|
||
@CustomOp.register("short_conv") | ||
class ShortConv(CustomOp): | ||
|
||
def __init__(self, config, dim: int, layer_idx: int, prefix: str = ""): | ||
super().__init__() | ||
self.config = config | ||
self.layer_idx = layer_idx | ||
self.conv_dim = dim | ||
self.L_cache = config.conv_L_cache | ||
self.bias = config.conv_bias | ||
|
||
self.conv = ColumnParallelLinear( | ||
input_size=self.L_cache, | ||
output_size=dim, | ||
bias=self.bias, | ||
prefix=f"{prefix}.conv1d", | ||
) | ||
# unsqueeze to fit conv1d weights shape into the linear weights shape. | ||
# Can't do this in `weight_loader` since it already exists in | ||
# `ColumnParallelLinear` and `set_weight_attrs` | ||
# doesn't allow to override it | ||
self.conv.weight.data = self.conv.weight.data.unsqueeze(1) | ||
|
||
self.in_proj = MergedColumnParallelLinear( | ||
input_size=dim, | ||
output_sizes=[dim] * 3, | ||
bias=self.bias, | ||
prefix=f"{prefix}.in_proj", | ||
) | ||
self.out_proj = RowParallelLinear( | ||
input_size=dim, | ||
output_size=dim, | ||
bias=self.bias, | ||
prefix=f"{prefix}.out_proj", | ||
) | ||
|
||
if envs.VLLM_USE_V1: | ||
compilation_config = get_current_vllm_config().compilation_config | ||
if prefix in compilation_config.static_forward_context: | ||
raise ValueError(f"Duplicate layer name: {prefix}") | ||
compilation_config.static_forward_context[prefix] = self | ||
# The outer list is for v0 PP virtual engine. Though this code path | ||
# only runs for v1, we have to do this to unify with the interface | ||
# of Attention + v0 PP. | ||
# The inner tuple is (conv_state,) | ||
self.kv_cache = [(torch.tensor([]))] | ||
|
||
# For compatibility with MambaSpec utils | ||
self.chunk_size = 1 | ||
self.prefix = prefix | ||
|
||
def forward_native(self, hidden_states: torch.Tensor, | ||
conv_cache_params: ConvCacheParams) -> torch.Tensor: | ||
pass | ||
|
||
def forward_cuda( | ||
self, | ||
hidden_states: torch.Tensor, | ||
conv_cache_params: ConvCacheParams, | ||
conv_metadata: Mamba2Metadata, | ||
) -> torch.Tensor: | ||
forward_context = get_forward_context() | ||
# mamba2_metadata contains metadata necessary for the mamba2 triton | ||
# kernels to operate in continuous batching and in chunked prefill | ||
# modes; they are computed at top-level model forward since they | ||
# stay the same and reused for all mamba layers in the same iteration | ||
attn_metadata: Optional[AttentionMetadata] = get_forward_context().attn_metadata | ||
if envs.VLLM_USE_V1: | ||
if attn_metadata is not None: | ||
assert isinstance(attn_metadata, dict) | ||
attn_metadata = attn_metadata[self.prefix] | ||
conv_metadata = attn_metadata | ||
assert isinstance(attn_metadata, Mamba2AttentionMetadata) | ||
self_kv_cache = self.kv_cache[forward_context.virtual_engine] | ||
conv_state = self_kv_cache[0].transpose(-1, -2) | ||
state_indices_tensor = attn_metadata.state_indices_tensor | ||
has_initial_states_p = attn_metadata.has_initial_states | ||
# prep_initial_states = attn_metadata.prep_initial_states | ||
# chunk_size = attn_metadata.chunk_size | ||
# seq_idx_p = attn_metadata.seq_idx | ||
# chunk_indices_p = attn_metadata.chunk_indices | ||
# chunk_offsets_p = attn_metadata.chunk_offsets | ||
paulpak58 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
conv_state = conv_cache_params.conv_state | ||
state_indices_tensor = conv_cache_params.state_indices_tensor | ||
has_initial_states_p = conv_metadata.has_initial_states | ||
# prep_initial_states = conv_metadata.prep_initial_states | ||
# chunk_size = conv_metadata.chunk_size | ||
# seq_idx_p = conv_metadata.seq_idx | ||
# chunk_indices_p = conv_metadata.chunk_indices | ||
# chunk_offsets_p = conv_metadata.chunk_offsets | ||
paulpak58 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
BCx, _ = self.in_proj(hidden_states) | ||
|
||
B, C, x = BCx.chunk(3, dim=-1) | ||
|
||
conv_weights = self.conv.weight.view(self.conv.weight.size(0), | ||
self.conv.weight.size(2)) | ||
|
||
if envs.VLLM_USE_V1 and attn_metadata is None: | ||
# V1 profile run | ||
Bx = (B * x).contiguous() | ||
hidden_states = C * Bx | ||
contextualized_states, _ = self.out_proj(hidden_states) | ||
return contextualized_states | ||
|
||
num_prefills = attn_metadata.num_prefills # request count | ||
num_decodes = attn_metadata.num_decode_tokens # token count (=request) | ||
paulpak58 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count | ||
has_prefill = num_prefills > 0 | ||
has_decode = num_decodes > 0 | ||
|
||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill | ||
# Separate prefill and decode by splitting varlen input | ||
# Split along token dimension | ||
if envs.VLLM_USE_V1: | ||
B_d, B_p = torch.split( | ||
B, | ||
[num_decodes, num_prefill_tokens], | ||
dim=0, | ||
) | ||
C_d, C_p = torch.split( | ||
C, | ||
[num_decodes, num_prefill_tokens], | ||
dim=0, | ||
) | ||
x_d, x_p = torch.split( | ||
x, | ||
[num_decodes, num_prefill_tokens], | ||
dim=0, | ||
) | ||
# Split along batch dimension | ||
state_indices_tensor_d, state_indices_tensor_p = torch.split( | ||
state_indices_tensor, | ||
[num_decodes, num_prefills], | ||
dim=0, | ||
) | ||
query_start_loc_p = ( | ||
attn_metadata.query_start_loc[-num_prefills - 1:] - | ||
num_decodes if has_prefill else None) | ||
else: | ||
B_p, B_d = torch.split( | ||
B, | ||
[num_prefill_tokens, num_decodes], | ||
dim=0, | ||
) | ||
C_p, C_d = torch.split( | ||
C, | ||
[num_prefill_tokens, num_decodes], | ||
dim=0, | ||
) | ||
x_p, x_d = torch.split( | ||
x, | ||
[num_prefill_tokens, num_decodes], | ||
dim=0, | ||
) | ||
# Split along batch dimension | ||
state_indices_tensor_p, state_indices_tensor_d = torch.split( | ||
conv_cache_params.state_indices_tensor, | ||
[num_prefills, num_decodes], | ||
dim=0, | ||
) | ||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + | ||
1] | ||
if has_prefill else None) | ||
|
||
conv_output_list = [] | ||
|
||
if has_prefill: | ||
Bx_p = (B_p * x_p).transpose(0, 1) | ||
if conv_metadata.cu_seqlen is None: | ||
conv_metadata = update_metadata(Bx_p, query_start_loc_p, | ||
conv_metadata) | ||
Bx = causal_conv1d_fn( | ||
Bx_p, | ||
conv_weights, | ||
self.conv.bias, | ||
activation=None, | ||
conv_states=conv_state, | ||
has_initial_state=has_initial_states_p, | ||
cache_indices=state_indices_tensor_p, | ||
metadata=conv_metadata, | ||
query_start_loc=query_start_loc_p).transpose( | ||
0, 1)[:num_prefill_tokens] | ||
|
||
C_p = C_p.view(1, num_prefill_tokens, -1) | ||
y = C_p * Bx | ||
conv_output_list.append(y.view(num_prefill_tokens, -1)) | ||
|
||
if has_decode: | ||
Bx_d = (B_d * x_d).contiguous() | ||
Bx = causal_conv1d_update( | ||
Bx_d, | ||
conv_state, | ||
conv_weights, | ||
self.conv.bias, | ||
activation=None, | ||
conv_state_indices=state_indices_tensor_d) | ||
C_d = C_d.view(num_decodes, -1) | ||
y = C_d * Bx | ||
conv_output_list.append(y.view(num_decodes, -1)) | ||
|
||
# Merge prefill and decode outputs before passing to gated MLP | ||
hidden_states = torch.vstack(conv_output_list) | ||
|
||
# Final linear projection | ||
contextualized_states, _ = self.out_proj(hidden_states) | ||
|
||
return contextualized_states | ||
|
||
|
||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: | ||
paulpak58 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
world_size = get_tensor_model_parallel_world_size() | ||
# contiguous along 'dim' axis | ||
conv_state_shape = ( | ||
self.L_cache - 1, | ||
divide(self.conv_dim, world_size), | ||
) | ||
return (conv_state_shape,) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
possibly a cleaner solution than this, but this works.