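"""Offline inference with prompt embeddings.

Instead of text or token IDs, vLLM is given precomputed input embeddings
(taken here from the Hugging Face model's embedding layer) via the
``prompt_embeds`` field, which requires ``enable_prompt_embeds=True``.
The example runs a single-prompt and a batched variant.
"""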
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizer)
from vllm import LLM


def init_tokenizer_and_llm(model_name: str):
    """Load the HF tokenizer, the model's embedding layer, and a vLLM engine."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    # enable_prompt_embeds=True allows passing `prompt_embeds` to llm.generate().
    llm = LLM(model=model_name, enable_prompt_embeds=True)
    return tokenizer, embedding_layer, llm


def get_prompt_embeds(chat: list[dict[str, str]],
                      tokenizer: PreTrainedTokenizer,
                      embedding_layer: torch.nn.Module):
    """Render the chat with the model's template and embed the resulting tokens."""
    token_ids = tokenizer.apply_chat_template(chat,
                                              add_generation_prompt=True,
                                              return_tensors='pt')
    # (1, seq_len) token IDs -> (seq_len, hidden_size) embeddings.
    prompt_embeds = embedding_layer(token_ids).squeeze(0)
    return prompt_embeds


def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                            embedding_layer: torch.nn.Module):
    """Generate a completion for a single prompt passed as embeddings."""
    chat = [{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    # Pass the embeddings instead of text or token IDs.
    outputs = llm.generate({
        "prompt_embeds": prompt_embeds,
    })

    print("\n[Single Inference Output]")
    print("-" * 30)
    for o in outputs:
        print(o.outputs[0].text)
    print("-" * 30)


def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                           embedding_layer: torch.nn.Module):
    """Generate completions for a batch of prompts passed as embeddings."""
    chats = [
        [{
            "role": "user",
            "content": "Please tell me about the capital of France."
        }],
        [{
            "role": "user",
            "content": "When is the day longest during the year?"
        }],
        [{
            "role": "user",
            "content": "Which is bigger, the moon or the sun?"
        }],
    ]

    prompt_embeds_list = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]

    # One request per set of prompt embeddings.
    outputs = llm.generate([{
        "prompt_embeds": embeds
    } for embeds in prompt_embeds_list])

    print("\n[Batch Inference Outputs]")
    print("-" * 30)
    for i, o in enumerate(outputs):
        print(f"Q{i+1}: {chats[i][0]['content']}")
        print(f"A{i+1}: {o.outputs[0].text}\n")
    print("-" * 30)


def main():
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
    single_prompt_inference(llm, tokenizer, embedding_layer)
    batch_prompt_inference(llm, tokenizer, embedding_layer)


if __name__ == "__main__":
    main()