[Model] Support Mamba #6484

Merged
merged 52 commits into main from tms/add_mamba on Oct 11, 2024
Changes from 42 commits

Commits (52)
ce630ea
WiP adding support for Mamba
tlrmchlsmth Jul 8, 2024
6c59b06
wip
tlrmchlsmth Jul 9, 2024
eb9bf34
WIP -- runs through. Generates tokens. Bad tokens.
tlrmchlsmth Jul 10, 2024
320f79b
Good output for mamba-370m
tlrmchlsmth Jul 15, 2024
5ab6622
wip
tlrmchlsmth Jul 16, 2024
71173a0
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 16, 2024
25b54d9
cleanup
tlrmchlsmth Jul 16, 2024
ebc12f1
Rename embedding block space manager
tlrmchlsmth Jul 16, 2024
ac60374
cleanup
tlrmchlsmth Jul 16, 2024
adb6713
remove file
tlrmchlsmth Jul 16, 2024
b733a84
format
tlrmchlsmth Jul 16, 2024
fb846ce
apply fix from #6214
tlrmchlsmth Jul 16, 2024
09b1495
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 16, 2024
d8017cb
fixes from 6425
tlrmchlsmth Jul 16, 2024
7ab2b9e
add an integration test
tlrmchlsmth Jul 23, 2024
c319a21
lint
tlrmchlsmth Jul 23, 2024
3374d8f
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 31, 2024
76022d3
fixup
tlrmchlsmth Jul 31, 2024
9ffc057
backend selector changes
tlrmchlsmth Jul 31, 2024
65d7e22
lint
tlrmchlsmth Jul 31, 2024
f14648e
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Aug 20, 2024
e76a617
Factor out mamba cache from jamba.py, and fixes
tlrmchlsmth Aug 20, 2024
b9723fe
Fix mamba cache initialized bool. format and renames
tlrmchlsmth Aug 21, 2024
b2a8cd8
Refactor mamba to use the MambaCacheManager
tlrmchlsmth Aug 21, 2024
9ba8734
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Aug 28, 2024
f87a8e2
fixes
tlrmchlsmth Aug 29, 2024
06b146e
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Aug 29, 2024
8e16aca
Update to use kernels from #7651
tlrmchlsmth Aug 29, 2024
120b761
some cruft
tlrmchlsmth Aug 29, 2024
698f666
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Sep 13, 2024
a5bd7d2
Move test_mamba.py (for #7820)
tlrmchlsmth Sep 13, 2024
6546bd9
fixes
tlrmchlsmth Sep 13, 2024
f42af9b
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Sep 23, 2024
85a8378
Review comments
tlrmchlsmth Sep 24, 2024
80e3c77
cache attention free
tlrmchlsmth Sep 24, 2024
184e808
fixup
tlrmchlsmth Sep 24, 2024
05d6aab
fixup
tlrmchlsmth Sep 24, 2024
4ebd4cc
missed two
tlrmchlsmth Sep 24, 2024
ca3788e
Remove is_attention_free from SchedulerConfig
tlrmchlsmth Sep 24, 2024
c67a650
default `is_attention_free` for unit tests
tlrmchlsmth Sep 25, 2024
9e2edf6
Fix attention selector tests
tlrmchlsmth Sep 25, 2024
f41b474
merge main, support chunked prefill, more tests
tlrmchlsmth Sep 30, 2024
7ef3c68
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 10, 2024
8729b43
Review comments
tlrmchlsmth Oct 10, 2024
5fb01c4
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 10, 2024
16d3f1d
format
tlrmchlsmth Oct 10, 2024
4b21a08
Fix supported_models.rst
tlrmchlsmth Oct 10, 2024
ec8ef04
jambafix
tlrmchlsmth Oct 10, 2024
49e1f3c
fix softfail on cpu tests
tlrmchlsmth Oct 11, 2024
e80b82a
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 11, 2024
609e9fb
fix for #9233
tlrmchlsmth Oct 11, 2024
93129e5
format
tlrmchlsmth Oct 11, 2024
37 changes: 21 additions & 16 deletions tests/kernels/test_attention_selector.py
@@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                    torch.float16, 16)
+        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                    False)
         assert backend.name == name


@@ -46,37 +46,42 @@ def test_flash_attn(monkeypatch):

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
+    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported sliding window
-    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
+    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

+    # Attention-free models should bypass env and use PlaceholderAttention
+    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                True)
+    assert backend.name != STR_FLASH_ATTN_VAL
+

 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        which_attn_to_use(16, None, torch.float16, None, 16, False)
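
The diff above moves the test calls from the old seven-argument form of `which_attn_to_use` to a six-argument form ending in an `is_attention_free` flag. The sketch below is a minimal, illustrative reconstruction of the eligibility checks that `test_flash_attn` asserts; the function name, return strings, and exact conditions are assumptions for illustration, not vLLM's actual selector code.

```python
from typing import Optional, Union

import torch


def which_attn_to_use_sketch(
        head_size: int,
        sliding_window: Optional[int],
        dtype: torch.dtype,
        kv_cache_dtype: Union[str, torch.dtype, None],
        block_size: int,
        is_attention_free: bool) -> str:
    """Illustrative-only dispatch mirroring the assertions in test_flash_attn."""
    if is_attention_free:
        # Attention-free models (e.g. pure Mamba) never select FlashAttention.
        return "NO_ATTENTION"
    if dtype not in (torch.float16, torch.bfloat16):
        return "XFORMERS"  # FlashAttention expects fp16/bf16 activations
    if isinstance(kv_cache_dtype, str) and kv_cache_dtype.startswith("fp8"):
        return "XFORMERS"  # fp8 KV cache rules out FlashAttention
    if block_size % 16 != 0 or head_size % 8 != 0:
        return "XFORMERS"  # unsupported block size or head size
    if sliding_window is not None:
        return "XFORMERS"  # sliding-window attention falls back as well
    return "FLASH_ATTN"
```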
295 changes: 295 additions & 0 deletions tests/models/decoder_only/language/test_mamba.py
@@ -0,0 +1,295 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/decoder_only/language/test_mamba.py`.
"""
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

MODELS = ["state-spaces/mamba-130m-hf"]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Create a text generation pipeline
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
# Tokenize the input prompt with truncation
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs["input_ids"].to(model.device)

# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)

outputs.append((generated_ids[0].tolist(), generated_text))

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_outputs = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])

batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)

check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# Tests chunked prefill in conjunction with n > 1. In this case, prefill is
# populated with decoding tokens and we test that it doesn't fail.
# This test might fail if the cache is not allocated correctly for n > 1
# decoding steps inside a chunked prefill forward pass (where we have both
# prefill and decode together).
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
max_tokens=max_tokens)
with vllm_runner(
model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=30,
max_num_seqs=10 # forces prefill chunks with decoding
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int,
chunked_prefill_token_size: int) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size

non_chunked = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:
chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)

check_outputs_equal(
outputs_0_lst=chunked,
outputs_1_lst=non_chunked,
name_0="chunked",
name_1="non_chunked",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
# using example_prompts index 1 instead of 0 since with 0 the
# logprobs get really close and the test doesn't pass
vllm_model.generate_greedy([example_prompts[1]], max_tokens)
[0])
sampling_params = SamplingParams(n=10,
temperature=0.001,
seed=0,
max_tokens=max_tokens)
n_gt_1_outputs = vllm_model.generate([example_prompts[1]],
sampling_params)
token_ids, texts = n_gt_1_outputs[0]
n_gt_1_outputs = [(token_id, text)
for token_id, text in zip(token_ids, texts)]

check_outputs_equal(
outputs_0_lst=n_gt_1_outputs,
outputs_1_lst=for_loop_outputs,
name_0="vllm_n_gt_1_outputs",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test verifies that the Mamba cache is padded to the CUDA graph
# captured batch size. If it's not, a torch RuntimeError will be raised
# because the tensor dimensions won't be compatible.
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and without preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)

vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test verifies that the Mamba inner state management doesn't
# collapse when the number of incoming requests and finished_requests_ids
# is larger than the maximum Mamba block capacity. This could generally
# happen because Mamba supports a statelessness mechanism that lets it
# clean up new incoming requests in a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test verifies that the Mamba state is cleaned up between
# steps. If it's not, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
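
Beyond the unit tests above, the simplest end-to-end check of this PR is vLLM's offline `LLM` API. Below is a minimal sketch, assuming a build that includes this branch so the Mamba architecture is registered; the prompt, `max_tokens`, and `dtype` values are arbitrary choices, with `float32` echoing the full precision the tests rely on.

```python
from vllm import LLM, SamplingParams

# Offline-inference smoke test for the newly supported Mamba architecture.
llm = LLM(model="state-spaces/mamba-130m-hf", dtype="float32")
params = SamplingParams(temperature=0.0, max_tokens=64)

for output in llm.generate(["The theory of relativity states that"], params):
    print(output.outputs[0].text)
```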