LFM2 #20797

Open
wants to merge 23 commits into base: main


Commits (23 total; showing changes from 14 commits)
3a50223
[cmake] ignore nvToolsExt for cuda-12.9
paulpak58 Jul 10, 2025
6fd86d2
[model_executor][models] LFM2 architecture
paulpak58 Jul 10, 2025
aaf7df1
[configs] use layer_types from huggingface hybrids >= 4.54.0.dev0
paulpak58 Jul 10, 2025
6c80caf
[model_runner][v1] ShortConvSpec for ShortConv layers; compatibility …
paulpak58 Jul 11, 2025
d17c95f
[configs] need to detect full_attention key in layer_types for transf…
paulpak58 Jul 11, 2025
1bc8835
[layers][conv] update ShortConv layer to be compatible with triton ca…
paulpak58 Jul 11, 2025
e550362
[transformers][ovis] tmp: AIMv2Config doesn't need to be registered o…
paulpak58 Jul 11, 2025
05af65a
[models][lfm2] LFM2->Lfm2 to match config
paulpak58 Jul 11, 2025
40d81e9
[merge] upstream @ 5bac61362b6718b90e708e7b212e7fcbe7d59fa3
paulpak58 Jul 15, 2025
7241660
[v1][cache] add support for conv cache shapes
paulpak58 Jul 15, 2025
3d3be6a
[v1][config] generalize HybridAttentionMambaModelConfig
paulpak58 Jul 15, 2025
b2447dd
[merge] upstream main @ 19c863068b2d70a452bde25318dbcf04f274ad19
paulpak58 Jul 15, 2025
46902dc
[layer][conv] update conv metadata in causal_conv1d
paulpak58 Jul 15, 2025
260e3fe
[misc] format + cleanup
paulpak58 Jul 15, 2025
1dff6e1
[layers][conv] fix minor discrepancies in decode conv
paulpak58 Jul 16, 2025
30621b4
[merge] upstream @ a0f8a7964694a6077689b242b5eca95de392d4bb
paulpak58 Jul 16, 2025
9c3edab
[layers][conv] fix ordering of prefill/decode tokens in conv layer
paulpak58 Jul 16, 2025
63cd12b
[tests] register LFM2 in test models
paulpak58 Jul 16, 2025
9af96d9
[tests][hybrid] include LFM2 in V1 Hybrids + include unsupported V1 a…
paulpak58 Jul 16, 2025
7577e89
[docs] update supported_models + v1 guide
paulpak58 Jul 16, 2025
1ff0c89
[misc] fix pre-commit checks
paulpak58 Jul 16, 2025
b425c0d
[model_executor] remap mamba V1 utils to static_cache + cleanup
paulpak58 Jul 16, 2025
80a2f3a
[misc] minor: fix format
paulpak58 Jul 16, 2025
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -79,6 +79,13 @@ endif()
#
find_package(Torch REQUIRED)

#
# Ignore nvToolsExt for cuda-12.9
#
if (NOT TARGET CUDA::nvToolsExt)
add_library(CUDA::nvToolsExt INTERFACE IMPORTED)
endif()

Comment from the author:
There's possibly a cleaner solution than this, but this works.

# Supported NVIDIA architectures.
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
22 changes: 18 additions & 4 deletions vllm/config.py
@@ -1373,6 +1373,13 @@ def get_num_layers_by_block_type(
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)

# Hybrid models in transformers >= 4.54.0.dev0
# populate a `layer_types` attribute
if layers_block_type_value is None:
layers_block_type_value = getattr(self.hf_text_config,
"layer_types", None)

if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
@@ -1382,8 +1389,14 @@
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

# Support hybrid transformers configs >= 4.54.0.dev0
if attn_block_type:
return sum(t in ("full_attention", "attention")
for t in layers_block_type_value[start:end])
else:
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
@@ -4820,13 +4833,14 @@ def try_verify_and_update_config(self):
return

from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
MODELS_CONFIG_MAP, HybridAttentionStaticCacheModelConfig)
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)

if self.model_config.is_hybrid:
HybridAttentionMambaModelConfig.verify_and_update_config(self)
HybridAttentionStaticCacheModelConfig.verify_and_update_config(
self)

if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
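
Note on the config.py change above: the new branch counts attention layers directly from the `layer_types` list that hybrid configs in transformers >= 4.54.0.dev0 expose, instead of the older `layers_block_type` attribute. Below is a minimal standalone sketch of that counting logic, using a hypothetical `layer_types` list rather than a real LFM2 config:

# Standalone sketch of the counting branch added to
# get_num_layers_by_block_type; the `layer_types` list is hypothetical.
def count_layers(layer_types: list[str], start: int, end: int,
                 attn_block_type: bool, block_type_value: str) -> int:
    if attn_block_type:
        # hybrids in transformers >= 4.54.0.dev0 may label attention
        # layers as either "full_attention" or "attention"
        return sum(t in ("full_attention", "attention")
                   for t in layer_types[start:end])
    return sum(t == block_type_value for t in layer_types[start:end])

layer_types = ["conv", "conv", "full_attention", "conv",
               "full_attention", "conv", "conv", "full_attention"]
print(count_layers(layer_types, 0, len(layer_types), True, ""))       # 3 attention layers
print(count_layers(layer_types, 0, len(layer_types), False, "conv"))  # 5 conv layers
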
243 changes: 243 additions & 0 deletions vllm/model_executor/layers/conv.py
@@ -0,0 +1,243 @@

from typing import Any, Optional

import torch
import torch.nn as nn

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.models.conv_cache import ConvCacheParams
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata


@CustomOp.register("short_conv")
class ShortConv(CustomOp):

def __init__(self, config, dim: int, layer_idx: int, prefix: str = ""):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.conv_dim = dim
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias

self.conv = ColumnParallelLinear(
input_size=self.L_cache,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit the conv1d weight shape into the linear weight shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear`, and `set_weight_attrs`
# doesn't allow overriding it.
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[dim] * 3,
bias=self.bias,
prefix=f"{prefix}.in_proj",
)
self.out_proj = RowParallelLinear(
input_size=dim,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.out_proj",
)

if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state,)
self.kv_cache = [(torch.tensor([]))]

# For compatibility with MambaSpec utils
self.chunk_size = 1
self.prefix = prefix

def forward_native(self, hidden_states: torch.Tensor,
conv_cache_params: ConvCacheParams) -> torch.Tensor:
pass

def forward_cuda(
self,
hidden_states: torch.Tensor,
conv_cache_params: ConvCacheParams,
conv_metadata: Mamba2Metadata,
) -> torch.Tensor:
forward_context = get_forward_context()
# conv_metadata contains metadata necessary for the causal-conv triton
# kernels to operate in continuous batching and in chunked prefill
# modes; it is computed at top-level model forward since it stays the
# same and is reused for all conv layers in the same iteration
attn_metadata: Optional[AttentionMetadata] = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
# prep_initial_states = attn_metadata.prep_initial_states
# chunk_size = attn_metadata.chunk_size
# seq_idx_p = attn_metadata.seq_idx
# chunk_indices_p = attn_metadata.chunk_indices
# chunk_offsets_p = attn_metadata.chunk_offsets
else:
conv_state = conv_cache_params.conv_state
state_indices_tensor = conv_cache_params.state_indices_tensor
has_initial_states_p = conv_metadata.has_initial_states
# prep_initial_states = conv_metadata.prep_initial_states
# chunk_size = conv_metadata.chunk_size
# seq_idx_p = conv_metadata.seq_idx
# chunk_indices_p = conv_metadata.chunk_indices
# chunk_offsets_p = conv_metadata.chunk_offsets

BCx, _ = self.in_proj(hidden_states)

B, C, x = BCx.chunk(3, dim=-1)

conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))

if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
Bx = (B * x).contiguous()
hidden_states = C * Bx
contextualized_states, _ = self.out_proj(hidden_states)
return contextualized_states

Check failure on line 129 in vllm/model_executor/layers/conv.py (GitHub Actions / pre-commit): Item "None" of "Optional[Any]" has no attribute "num_prefills" [union-attr]

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0

# NOTE: V0 puts prefill before decode, V1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
B_d, B_p = torch.split(
B,
[num_decodes, num_prefill_tokens],
dim=0,
)
C_d, C_p = torch.split(
C,
[num_decodes, num_prefill_tokens],
dim=0,
)
x_d, x_p = torch.split(
x,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
B_p, B_d = torch.split(
B,
[num_prefill_tokens, num_decodes],
dim=0,
)
C_p, C_d = torch.split(
C,
[num_prefill_tokens, num_decodes],
dim=0,
)
x_p, x_d = torch.split(
x,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
conv_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)

conv_output_list = []

if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(
Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=conv_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]

C_p = C_p.view(1, num_prefill_tokens, -1)
y = C_p * Bx
conv_output_list.append(y.view(num_prefill_tokens, -1))

if has_decode:
Bx_d = (B_d * x_d).contiguous()
Bx = causal_conv1d_update(
Bx_d,
conv_state,
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
C_d = C_d.view(num_decodes, -1)
y = C_d * Bx
conv_output_list.append(y.view(num_decodes, -1))

# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(conv_output_list)

# Final linear projection
contextualized_states, _ = self.out_proj(hidden_states)

return contextualized_states


def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
world_size = get_tensor_model_parallel_world_size()
# contiguous along 'dim' axis
conv_state_shape = (
self.L_cache - 1,
divide(self.conv_dim, world_size),
)
return (conv_state_shape,)
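
For readers new to the layer: forward_cuda implements a gated short convolution, i.e. a depthwise causal conv of width L_cache applied to B * x, gated by C, followed by an output projection. The snippet below is a plain-PyTorch reference of that math for a single sequence; it is only an illustration under stated assumptions (no conv-state cache, no tensor parallelism, no prefill/decode split, conv bias omitted, and `short_conv_reference` is not a vLLM API), not the kernel path used above.

import torch
import torch.nn.functional as F

# Plain-PyTorch sketch of the gated short convolution that ShortConv
# runs through the causal_conv1d kernels. Single sequence only;
# no conv-state cache, no tensor parallelism, conv bias omitted.
def short_conv_reference(hidden_states: torch.Tensor,   # (num_tokens, dim)
                         in_proj: torch.nn.Linear,      # dim -> 3 * dim
                         conv_weight: torch.Tensor,     # (dim, L_cache), depthwise
                         out_proj: torch.nn.Linear      # dim -> dim
                         ) -> torch.Tensor:
    B, C, x = in_proj(hidden_states).chunk(3, dim=-1)
    Bx = (B * x).transpose(0, 1).unsqueeze(0)            # (1, dim, num_tokens)
    dim, L_cache = conv_weight.shape
    # Left-pad by L_cache - 1 so the depthwise convolution is causal.
    Bx = F.pad(Bx, (L_cache - 1, 0))
    conv_out = F.conv1d(Bx, conv_weight.unsqueeze(1), groups=dim)
    conv_out = conv_out.squeeze(0).transpose(0, 1)       # (num_tokens, dim)
    # Output gate, then final projection.
    return out_proj(C * conv_out)

This also explains get_state_shape: at decode time causal_conv1d_update only needs the last L_cache - 1 inputs per channel, so the per-sequence conv state is (L_cache - 1, conv_dim / tp_world_size).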