@@ -23,23 +23,19 @@
 from typing import Any, List, Optional, Tuple, TypeVar, Union

 import numpy as np
-import pytest
 import torch
-from modelscope import snapshot_download  # type: ignore[import-untyped]
 from PIL import Image
 from torch import nn
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BatchEncoding, BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 from vllm import LLM, SamplingParams
 from vllm.config import TaskOption, _get_and_verify_dtype
-from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
+from vllm.inputs import TextPrompt
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import is_list_of

-from tests.e2e.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
+from tests.e2e.model_utils import (TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs)
 # TODO: remove this part after the patch is merged into vllm; if we do
 # not explicitly patch here, some of the patches might be ineffective
@@ -62,7 +58,6 @@
 PromptVideoInput = _PromptMultiModalInput[np.ndarray]

 _TEST_DIR = os.path.dirname(__file__)
-_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]


 def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
@@ -95,7 +90,7 @@ def __init__(
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
         swap_space: int = 4,
-        enforce_eager: Optional[bool] = True,
+        enforce_eager: Optional[bool] = False,
         quantization: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -220,26 +215,6 @@ def generate_w_logprobs(
                 if sampling_params.prompt_logprobs is None else
                 toks_str_logsprobs_prompt_logprobs)

-    def generate_encoder_decoder_w_logprobs(
-        self,
-        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
-        sampling_params: SamplingParams,
-    ) -> Union[List[TokensTextLogprobs],
-               List[TokensTextLogprobsPromptLogprobs]]:
-        '''
-        Logprobs generation for vLLM encoder/decoder models
-        '''
-
-        assert sampling_params.logprobs is not None
-        req_outputs = self.model.generate(encoder_decoder_prompts,
-                                          sampling_params=sampling_params)
-        toks_str_logsprobs_prompt_logprobs = (
-            self._final_steps_generate_w_logprobs(req_outputs))
-        # Omit prompt logprobs if not required by sampling params
-        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
-                if sampling_params.prompt_logprobs is None else
-                toks_str_logsprobs_prompt_logprobs)
-
     def generate_greedy(
         self,
         prompts: List[str],
@@ -284,53 +259,6 @@ def generate_greedy_logprobs(
                                         audios=audios,
                                         videos=videos)

-    def generate_encoder_decoder_greedy_logprobs(
-        self,
-        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
-        max_tokens: int,
-        num_logprobs: int,
-        num_prompt_logprobs: Optional[int] = None,
-    ) -> Union[List[TokensTextLogprobs],
-               List[TokensTextLogprobsPromptLogprobs]]:
-        greedy_logprobs_params = SamplingParams(
-            temperature=0.0,
-            max_tokens=max_tokens,
-            logprobs=num_logprobs,
-            prompt_logprobs=(num_prompt_logprobs),
-        )
-        '''
-        Greedy logprobs generation for vLLM encoder/decoder models
-        '''
-
-        return self.generate_encoder_decoder_w_logprobs(
-            encoder_decoder_prompts, greedy_logprobs_params)
-
-    def generate_beam_search(
-        self,
-        prompts: Union[List[str], List[List[int]]],
-        beam_width: int,
-        max_tokens: int,
-    ) -> List[Tuple[List[List[int]], List[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
-        outputs = self.model.beam_search(
-            prompts,
-            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
-        returned_outputs = []
-        for output in outputs:
-            token_ids = [x.tokens for x in output.sequences]
-            texts = [x.text for x in output.sequences]
-            returned_outputs.append((token_ids, texts))
-        return returned_outputs
-
-    def classify(self, prompts: List[str]) -> List[List[float]]:
-        req_outputs = self.model.classify(prompts)
-        return [req_output.outputs.probs for req_output in req_outputs]
-
     def encode(
         self,
         prompts: List[str],
@@ -346,50 +274,6 @@ def encode(
         req_outputs = self.model.embed(inputs)
         return [req_output.outputs.embedding for req_output in req_outputs]

-    def score(
-        self,
-        text_1: Union[str, List[str]],
-        text_2: Union[str, List[str]],
-    ) -> List[float]:
-        req_outputs = self.model.score(text_1, text_2)
-        return [req_output.outputs.score for req_output in req_outputs]
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        del self.model
-        cleanup_dist_env_and_memory()
-
-
-@pytest.fixture(scope="session")
-def vllm_runner():
-    return VllmRunner
-
-
-@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
-def prompt_template(request):
-    return PROMPT_TEMPLATES[request.param]
-
-
-def _read_prompts(filename: str) -> list[str]:
-    with open(filename) as f:
-        prompts = f.readlines()
-    return prompts
-
-
-@pytest.fixture
-def example_prompts() -> list[str]:
-    prompts = []
-    for filename in _TEST_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
-
-
-@pytest.fixture(scope="session")
-def ilama_lora_files():
-    return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider")
-

 class HfRunner:

@@ -502,18 +386,9 @@ def __init__(
         if skip_tokenizer_init:
             self.tokenizer = self.processor.tokenizer

-    def encode(self, prompts: list[str], *args,
-               **kwargs) -> list[list[torch.Tensor]]:
-        return self.model.encode(prompts, *args, **kwargs)
-
     def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
         cleanup_dist_env_and_memory()
-
-
-@pytest.fixture(scope="session")
-def hf_runner():
-    return HfRunner
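Usage sketch (not part of the diff above): with the vllm_runner fixture and VllmRunner's context-manager hooks removed, a test would construct the runner directly and clean up explicitly. This is a minimal sketch only; the conftest import path, the model name, and the generate_greedy signature beyond what the diff shows are assumptions.

    # Minimal sketch, assuming this diff lands in tests/e2e/conftest.py (assumed path).
    from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory

    def test_greedy_smoke():
        # Construct the runner directly; the vllm_runner fixture no longer exists.
        runner = VllmRunner("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model name
        # generate_greedy(prompts, max_tokens) is assumed from the partial signature above.
        outputs = runner.generate_greedy(["Hello, my name is"], 16)
        assert len(outputs) == 1
        # Explicit teardown replaces the removed __exit__ hook.
        del runner
        cleanup_dist_env_and_memory()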