Commit c319a21

lint

1 parent 7ab2b9e commit c319a21

File tree

1 file changed: +6 -8 lines

tests/models/test_mamba.py

Lines changed: 6 additions & 8 deletions
@@ -3,41 +3,39 @@
 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
-import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .utils import check_outputs_equal
 
 MODELS = [
     "state-spaces/mamba-370m-hf",
 ]
 
+
 # Use lower-level interfaces to create this greedy generator, as mamba will
 # choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
 def generate_greedy(model_name, example_prompts, max_tokens):
     # Create a text generation pipeline
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
-    generator = TextGenerationPipeline(model=model, tokenizer=tokenizer,
-                                       device=torch.cuda.current_device()
-                                       if torch.cuda.is_available() else -1)
-
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
         # Tokenize the input prompt with truncation
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
         input_ids = inputs["input_ids"].to(model.device)
-
+
         # Generate text using the model's generate method directly
         generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        generated_text = tokenizer.decode(generated_ids[0],
+                                          skip_special_tokens=True)
 
         outputs.append((generated_ids[0].tolist(), generated_text))
 
     return outputs
 
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [96])

0 commit comments
