LFM2 #20797

Open
wants to merge 27 commits into base: main

Commits (27)
3a50223
[cmake] ignore nvToolsExt for cuda-12.9
paulpak58 Jul 10, 2025
6fd86d2
[model_executor][models] LFM2 architecture
paulpak58 Jul 10, 2025
aaf7df1
[configs] use layer_types from huggingface hybrids >= 4.54.0.dev0
paulpak58 Jul 10, 2025
6c80caf
[model_runner][v1] ShortConvSpec for ShortConv layers; compatibility …
paulpak58 Jul 11, 2025
d17c95f
[configs] need to detect full_attention key in layer_types for transf…
paulpak58 Jul 11, 2025
1bc8835
[layers][conv] update ShortConv layer to be compatible with triton ca…
paulpak58 Jul 11, 2025
e550362
[transformers][ovis] tmp: AIMv2Config doesn't need to be registered o…
paulpak58 Jul 11, 2025
05af65a
[models][lfm2] LFM2->Lfm2 to match config
paulpak58 Jul 11, 2025
40d81e9
[merge] upstream @ 5bac61362b6718b90e708e7b212e7fcbe7d59fa3
paulpak58 Jul 15, 2025
7241660
[v1][cache] add support for conv cache shapes
paulpak58 Jul 15, 2025
3d3be6a
[v1][config] generalize HybridAttentionMambaModelConfig
paulpak58 Jul 15, 2025
b2447dd
[merge] upstream main @ 19c863068b2d70a452bde25318dbcf04f274ad19
paulpak58 Jul 15, 2025
46902dc
[layer][conv] update conv metadata in causal_conv1d
paulpak58 Jul 15, 2025
260e3fe
[misc] format + cleanup
paulpak58 Jul 15, 2025
1dff6e1
[layers][conv] fix minor discrepancies in decode conv
paulpak58 Jul 16, 2025
30621b4
[merge] upstream @ a0f8a7964694a6077689b242b5eca95de392d4bb
paulpak58 Jul 16, 2025
9c3edab
[layers][conv] fix ordering of prefill/decode tokens in conv layer
paulpak58 Jul 16, 2025
63cd12b
[tests] register LFM2 in test models
paulpak58 Jul 16, 2025
9af96d9
[tests][hybrid] include LFM2 in V1 Hybrids + include unsupported V1 a…
paulpak58 Jul 16, 2025
7577e89
[docs] update supported_models + v1 guide
paulpak58 Jul 16, 2025
1ff0c89
[misc] fix pre-commit checks
paulpak58 Jul 16, 2025
b425c0d
[model_executor] remap mamba V1 utils to static_cache + cleanup
paulpak58 Jul 16, 2025
80a2f3a
[misc] minor: fix format
paulpak58 Jul 16, 2025
5f2c6c8
[merge] upstream @ 9fb2d22032cee577a189f8c4cddd88a3c190cb0c
paulpak58 Jul 17, 2025
cbc6ba3
[merge] upstream @ ed8cbfedf84f1b1fc1d3eadf3622d9903e906cb0
paulpak58 Jul 18, 2025
7318322
[merge] upstream @ 881e3cbe3b3cef5d6fc50ca0c19e30a9dd11c452
paulpak58 Jul 19, 2025
d8b7170
[models][lfm2] torch compile support
paulpak58 Jul 19, 2025
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -79,6 +79,13 @@ endif()
#
find_package(Torch REQUIRED)

#
# Ignore nvToolsExt for cuda-12.9
#
if (NOT TARGET CUDA::nvToolsExt)
add_library(CUDA::nvToolsExt INTERFACE IMPORTED)
endif()

Author (paulpak58) commented:
possibly a cleaner solution than this, but this works.

# Supported NVIDIA architectures.
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -358,6 +358,7 @@ Specified using `--task generate`.
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
4 changes: 4 additions & 0 deletions docs/usage/v1_guide.md
@@ -114,6 +114,10 @@ Models that combine Mamba-2 layers with standard attention layers are also suppo
these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
backend in V1.

Hybrid models whose sub-components are similar to Mamba-2 layers, e.g. the ShortConv layers in LFM2, are also supported.
As above, these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer
attention backend in V1.
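
For reference, the snippet below is a minimal sketch of launching such a model under these constraints. The flags mirror what this PR exercises in its tests (FlashInfer backend via `VLLM_ATTENTION_BACKEND`, eager mode, prefix caching disabled); the prompt and sampling settings are illustrative assumptions, not part of the PR.

```python
import os

from vllm import LLM, SamplingParams

# Select the FlashInfer attention backend for V1, as done in
# tests/models/language/generation/test_hybrid.py.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

# Hybrid models with ShortConv layers (e.g. LFM2) currently require
# eager mode and disabled prefix caching in V1.
llm = LLM(
    model="LiquidAI/LFM2-1.2B",
    enforce_eager=True,
    enable_prefix_caching=False,
    max_num_seqs=4,  # keep the batch small to avoid OOM, as in the tests
)

outputs = llm.generate(
    ["The LFM2 architecture combines"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```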

#### Encoder-Decoder Models

Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`)
60 changes: 41 additions & 19 deletions tests/models/language/generation/test_hybrid.py
@@ -35,6 +35,7 @@
"nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B"
]

HF_UNSUPPORTED_MODELS = [
@@ -53,17 +54,21 @@
]

V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",
"nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct", "nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B"
]

# Avoid OOM
MAX_NUM_SEQS = 4

# To be removed once V1 hybrid models no longer require these restrictions
V1_HYBRID_UNSUPPORTED_ARGS = {
"enforce_eager": True,
"enable_prefix_caching": False,
}


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@@ -104,7 +109,7 @@ def test_models(
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enable_prefix_caching=False) as vllm_model:
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
@@ -182,13 +187,15 @@ def test_chunked_prefill(

with vllm_runner(model,
enable_chunked_prefill=True,
**V1_HYBRID_UNSUPPORTED_ARGS,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:
chunked = vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens, num_logprobs)

with vllm_runner(model,
enable_chunked_prefill=False,
**V1_HYBRID_UNSUPPORTED_ARGS,
max_num_seqs=max_num_seqs) as vllm_model:
non_chunked = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
@@ -229,6 +236,7 @@ def test_chunked_prefill_with_parallel_sampling(
# forces prefill chunks with decoding
max_num_batched_tokens=MAX_NUM_SEQS * 3,
max_num_seqs=MAX_NUM_SEQS,
**V1_HYBRID_UNSUPPORTED_ARGS,
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)

@@ -253,7 +261,7 @@ def test_mamba_cache_cg_padding(
example_prompts.append(example_prompts[0])

try:
with vllm_runner(model) as vllm_model:
with vllm_runner(model, **V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
@@ -273,7 +281,9 @@ def test_models_preemption_recompute(
"""
Tests that outputs are identical with and w/o preemptions (recompute).
"""
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
scheduler = vllm_model.model.llm_engine.scheduler[0]
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -306,7 +316,9 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
a single step.
"""
try:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Hybrid inner state wasn't cleaned up properly between"
@@ -326,7 +338,9 @@ def test_state_cleanup(
If its not cleaned, an error would be expected.
"""
try:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
@@ -342,13 +356,17 @@ def test_multistep_correctness(
model: str,
max_tokens: int,
) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
with vllm_runner(model,
num_scheduler_steps=8,
max_num_seqs=2,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_outputs_multistep = vllm_model.generate_greedy(
example_prompts, max_tokens)

with vllm_runner(model, num_scheduler_steps=1,
max_num_seqs=2) as vllm_model:
with vllm_runner(model,
num_scheduler_steps=1,
max_num_seqs=2,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_outputs_single_step = vllm_model.generate_greedy(
example_prompts, max_tokens)

@@ -371,13 +389,17 @@ def test_distributed_correctness(
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model, tensor_parallel_size=1,
max_num_seqs=2) as vllm_model:
with vllm_runner(model,
tensor_parallel_size=1,
max_num_seqs=2,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, tensor_parallel_size=2,
max_num_seqs=2) as vllm_model:
with vllm_runner(model,
tensor_parallel_size=2,
max_num_seqs=2,
**V1_HYBRID_UNSUPPORTED_ARGS) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -208,6 +208,8 @@ def check_available_online(
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
min_transformers_version="4.54"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
30 changes: 23 additions & 7 deletions vllm/config.py
@@ -1379,6 +1379,13 @@ def get_num_layers_by_block_type(
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)

# Hybrid models in transformers >= 4.54.0.dev0
# populate a `layer_types` attribute
if layers_block_type_value is None:
layers_block_type_value = getattr(self.hf_text_config,
"layer_types", None)

if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
@@ -1388,8 +1395,14 @@
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

# Support hybrid configs from transformers >= 4.54.0.dev0
if attn_block_type:
return sum(t in ("full_attention", "attention")
for t in layers_block_type_value[start:end])
else:
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
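
Aside: the `layer_types` handling added in this hunk boils down to counting entries per block type over a layer range. A simplified, standalone sketch of that counting (an illustrative assumption only, not the exact vLLM code path; the example `layer_types` values are made up):

```python
# Count layers of a given block type from a transformers >= 4.54 style
# `layer_types` list, treating "full_attention"/"attention" as attention.
def count_layers(layer_types: list[str], start: int, end: int,
                 block_type: str, is_attention: bool) -> int:
    window = layer_types[start:end]
    if is_attention:
        return sum(t in ("full_attention", "attention") for t in window)
    return sum(t == block_type for t in window)


# Hypothetical hybrid layout mixing conv and attention layers.
layer_types = ["conv", "conv", "full_attention", "conv", "full_attention"]
print(count_layers(layer_types, 0, len(layer_types), "conv", False))      # 3
print(count_layers(layer_types, 0, len(layer_types), "attention", True))  # 2
```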
@@ -1634,9 +1647,10 @@ class CacheConfig:
checkpoint if available. Otherwise, the scales will default to 1.0."""
cpu_kvcache_space_bytes: Optional[int] = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: Optional[int] = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
static_cache_page_size_padded: Optional[int] = None
""" Optional override for static cache page size; used by hybrid static
cache (e.g. mamba, short-conv) / attention models to ensure exact alignment
with attention page size."""

# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
@@ -4313,6 +4327,7 @@ def set_splitting_ops_for_v1(self):
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
"vllm.short_conv",
]


@@ -4789,13 +4804,14 @@ def try_verify_and_update_config(self):
return

from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
MODELS_CONFIG_MAP, HybridAttentionStaticCacheModelConfig)
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)

if self.model_config.is_hybrid:
HybridAttentionMambaModelConfig.verify_and_update_config(self)
HybridAttentionStaticCacheModelConfig.verify_and_update_config(
self)

if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
Expand Down