
Commit 7ab2b9e

committed
add an integration test
1 parent d8017cb commit 7ab2b9e

File tree

1 file changed: +79, -0 lines


tests/models/test_mamba.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/test_mamba.py`.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline

from .utils import check_outputs_equal

MODELS = [
    "state-spaces/mamba-370m-hf",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
    # Create a text generation pipeline
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    generator = TextGenerationPipeline(model=model, tokenizer=tokenizer,
                                       device=torch.cuda.current_device()
                                       if torch.cuda.is_available() else -1)

    # Generate texts from the prompts
    outputs = []
    for prompt in example_prompts:
        # Tokenize the input prompt with truncation
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].to(model.device)

        # Generate text using the model's generate method directly
        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
        generated_text = tokenizer.decode(generated_ids[0],
                                          skip_special_tokens=True)

        outputs.append((generated_ids[0].tolist(), generated_text))

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # To pass the small model tests, we need full precision.
    assert dtype == "float"

    hf_outputs = generate_greedy(model, example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # This test is for verifying whether the model's extra_repr
        # can be printed correctly.
        print(vllm_model.model.llm_engine.model_executor.driver_worker.
              model_runner.model)
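
For context, `check_outputs_equal` comes from the local test utilities imported at the top of the file. A minimal sketch of the kind of per-prompt comparison such a helper is assumed to perform is shown below; this is an illustrative assumption for readers, not the actual implementation in `tests/models/utils.py`.

from typing import List, Tuple

def check_outputs_equal_sketch(
    outputs_0_lst: List[Tuple[List[int], str]],
    outputs_1_lst: List[Tuple[List[int], str]],
    name_0: str,
    name_1: str,
) -> None:
    # Hypothetical helper: compare per-prompt (token_ids, text) pairs from
    # two runners and fail with a readable message on the first mismatch.
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        ids_0, text_0 = out_0
        ids_1, text_1 = out_1
        assert ids_0 == ids_1, (
            f"Prompt {i}: token ids differ\n"
            f"{name_0}: {ids_0}\n{name_1}: {ids_1}")
        assert text_0 == text_1, (
            f"Prompt {i}: texts differ\n"
            f"{name_0}: {text_0!r}\n{name_1}: {text_1!r}")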
