diff --git a/.gitignore b/.gitignore
index 54a5a08..c7c1fbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,4 +171,7 @@ cython_debug/
 .pypirc
 .vscode/
 app/.gradio/
-test*
\ No newline at end of file
+test*
+
+
+t.ipynb
\ No newline at end of file
diff --git a/README.md b/README.md
index 6dded76..1748fec 100644
--- a/README.md
+++ b/README.md
@@ -186,6 +186,7 @@ Evaluated on `ananyarn/Algorithm_and_Python_Source_Code`.
 ### 2.1 Install
 ```bash
 conda create -n umbrella python=3.10
+conda activate umbrella
 bash install.sh
 ```
 ### 2.2 CLI Chatbot
diff --git a/configs/chat_config_ar.json b/configs/chat_config_ar.json
new file mode 100644
index 0000000..450232e
--- /dev/null
+++ b/configs/chat_config_ar.json
@@ -0,0 +1,15 @@
+{
+    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
+    "offload": false,
+    "max_length": 2048,
+    "num_cache_layers": 0,
+    "generation_length": 256,
+    "max_turns": 16,
+    "topk": 32,
+    "temperature": 0.6,
+    "topp": 0.9,
+    "repetition_penalty": 1.05,
+    "exit_layer": 16,
+    "engine": "ar",
+    "template": "meta-llama3"
+}
diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py
index 533206d..c9071c7 100644
--- a/umbrella/attn/cache.py
+++ b/umbrella/attn/cache.py
@@ -1,7 +1,9 @@
 from transformers import AutoConfig
 import torch
 import flashinfer
+from flash_attn import flash_attn_with_kvcache
 import math
+
 class KV_Cache:
 
     def __init__(self,
@@ -16,6 +18,7 @@ def __init__(self,
         self.dtype = dtype
         self.k_cache = torch.zeros(
             config.num_hidden_layers,
+            batch_size,
             max_length,
             config.num_key_value_heads,
             config.hidden_size // config.num_attention_heads,
@@ -25,6 +28,7 @@
 
         self.v_cache = torch.zeros(
             config.num_hidden_layers,
+            batch_size,
             max_length,
             config.num_key_value_heads,
             config.hidden_size // config.num_attention_heads,
@@ -40,11 +44,11 @@ def __init__(self,
     def gather_kv_incremental(self,
                             indices: torch.LongTensor,
                             offset:int):
 
-        self.k_cache[:,offset:offset + len(indices), :,:] = self.k_cache[:,indices, :,:]
-        self.v_cache[:,offset:offset + len(indices), :,:] = self.v_cache[:,indices, :,:]
+        self.k_cache[:,:,offset:offset + len(indices), :,:] = self.k_cache[:,:,indices, :,:]
+        self.v_cache[:,:,offset:offset + len(indices), :,:] = self.v_cache[:,:,indices, :,:]
 
-        self.k_cache[:,offset + len(indices):, :,:] = 0.0
-        self.v_cache[:,offset + len(indices):, :,:] = 0.0
+        self.k_cache[:,:,offset + len(indices):, :,:] = 0.0
+        self.v_cache[:,:,offset + len(indices):, :,:] = 0.0
 
         self.kv_offset = offset + len(indices)
@@ -54,33 +58,39 @@ def update_kv_cache(self,
                         new_k_cache :torch.Tensor,
                         new_v_cache :torch.Tensor,
                         layer_idx :int,
-                        storage_ids: torch.LongTensor
+                        storage_ids: torch.LongTensor=None
                         ):
 
-        new_kv_len = storage_ids.shape[0]
+        new_kv_len = new_k_cache.shape[1] # [bsz, seq, num_heads, head_dim]
         if layer_idx == 0:
             self.kv_offset += new_kv_len
 
-        self.k_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache
-        self.v_cache[layer_idx][self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache
-        return self.k_cache[layer_idx][:self.kv_offset], self.v_cache[layer_idx][:self.kv_offset]
+        self.k_cache[layer_idx][:, self.kv_offset - new_kv_len:self.kv_offset] = new_k_cache
+        self.v_cache[layer_idx][:, self.kv_offset - new_kv_len:self.kv_offset] = new_v_cache
+        return self.k_cache[layer_idx][:, :self.kv_offset], self.v_cache[layer_idx][:, :self.kv_offset]
 
     def compute_attention(self,
                         query_states :torch.Tensor,
                         key_states :torch.Tensor,
                        value_states :torch.Tensor,
                         layer_idx,
-                        storage_ids :torch.Tensor,
-                        attention_mask :torch.Tensor):
+                        storage_ids :torch.Tensor=None,
+                        attention_mask :torch.Tensor=None):
+
+        key_states, value_states = self.update_kv_cache(key_states, value_states, layer_idx, storage_ids)
 
-        key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids)
-        hidden_states = flashinfer.single_prefill_with_kv_cache(
-            q = query_states[0],
-            k = key_states,
-            v = value_states,
-            kv_layout="NHD",
-            custom_mask=attention_mask[:,:self.kv_offset],
-            allow_fp16_qk_reduction=True
-        )
+        if attention_mask is not None:
+            hidden_states = flashinfer.single_prefill_with_kv_cache(
+                q = query_states[0],
+                k = key_states[0],
+                v = value_states[0],
+                kv_layout="NHD",
+                custom_mask=attention_mask[:,:self.kv_offset],
+                allow_fp16_qk_reduction=True
+            )
+        else:
+            # do not use attn mask
+            # print(query_states.shape, key_states.shape, value_states.shape)
+            hidden_states = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, causal=True)
 
         return hidden_states
@@ -92,6 +102,9 @@ def clear(self):
     def set_kv_len(self, kv_len :int):
         self.kv_offset = kv_len
 
+    def get_kv_len(self):
+        return self.kv_offset
+
 
 class StaticKV_Cache:
 
diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py
index 20f9483..d032054 100644
--- a/umbrella/models/auto_model.py
+++ b/umbrella/models/auto_model.py
@@ -60,6 +60,9 @@ class AutoModelLM:
         "meta-llama/Llama-3.3-70B-Instruct": Llama,
         "meta-llama/Llama-3.1-70B-Instruct": Llama,
         "meta-llama/Llama-3.1-8B-Instruct": Llama,
+
+        "gradientai/Llama-3-8B-Instruct-Gradient-1048k": Llama,
+
         "meta-llama/Meta-Llama-3-70B-Instruct": Llama,
         "meta-llama/Meta-Llama-3-8B-Instruct": Llama,
         "meta-llama/Llama-3.2-1B-Instruct": Llama,
diff --git a/umbrella/models/llama.py b/umbrella/models/llama.py
index 3c18b46..1c4034b 100644
--- a/umbrella/models/llama.py
+++ b/umbrella/models/llama.py
@@ -8,6 +8,7 @@ from .base import LLMBase
 from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph
 from tqdm import tqdm
 
+
 class Llama(LLMBase):
     def __init__(self,
                  model_name: str,
@@ -78,8 +79,8 @@ def layer_compute(self,
                       layer_idx :int,
                       hidden_states: torch.FloatTensor,
                       position_ids: torch.LongTensor,
-                      attention_mask: torch.FloatTensor,
-                      storage_ids: torch.LongTensor):
+                      attention_mask: torch.FloatTensor=None,
+                      storage_ids: torch.LongTensor=None):
 
         residual = hidden_states
         bsz, q_len, _ = hidden_states.size()
@@ -118,8 +119,8 @@ def layer_compute(self,
     def inference(self,
                   input_ids: torch.LongTensor,
                   position_ids: torch.LongTensor,
-                  attention_mask: torch.FloatTensor,
-                  storage_ids: torch.LongTensor):
+                  attention_mask: torch.FloatTensor=None,
+                  storage_ids: torch.LongTensor=None):
 
         hidden_states = F.embedding(input_ids, self.embed_tokens)
         for idx in range(self.num_layers):
diff --git a/umbrella/speculation/ar_engine.py b/umbrella/speculation/ar_engine.py
new file mode 100644
index 0000000..d10ad12
--- /dev/null
+++ b/umbrella/speculation/ar_engine.py
@@ -0,0 +1,300 @@
+import torch
+from ..models import AutoModelLM
+from transformers import AutoTokenizer, GenerationConfig
+from .speculation_utils import (
+make_causal_mask,
+find_first_element_position,
+apply_repetition_penalty,
+apply_topk,
+is_sentence_complete_regex
+)
+import time
+import flashinfer
+from ..logging_config import setup_logger
+from ..utils import TextColors
+from .base import BaseEngine
+logger = setup_logger()
+
+class AREngine:
+
+    def __init__(self,
+                 model_name: str,
+                 dtype=torch.float16,
+                 device :str = 'cuda:0',
+                 **kwargs
+                 ) -> None:
+
+        super().__init__()
+
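+        # Basic runtime settings; any kwargs left after the pops below are stored
+        # in self.config and forwarded to self.model.alloc() in initialize().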
+        self.model_name = model_name
+        self.dtype = dtype
+        self.device = device
+
+        self.max_length = kwargs.pop("max_length", 8192)
+        self.safe_buffer = kwargs.pop("safe_buffer", 8)
+        self.temperature = kwargs.pop("temperature", 0.0)
+        self.topp = kwargs.pop("topp", 0.9)
+        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+        self.topk = kwargs.pop("topk", 32)
+        self.num_beams = kwargs.pop("num_beams", 24)
+        self.offload = kwargs.pop("offload", False)
+        self.config = kwargs
+
+    def initialize(self):
+
+        self.tokens = torch.zeros(1, self.max_length, device=self.device).long()
+        self.model = AutoModelLM.from_pretrained(
+            model_name=self.model_name, offload=self.offload, batch_size=1,
+            max_length=self.max_length, device=self.device,
+            dtype=self.dtype)
+
+        self.model.alloc(**self.config)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.vocab_size = self.model.config.vocab_size
+        self.generation_config = GenerationConfig.from_pretrained(self.model_name)
+        self.eos_tokens = self.generation_config.eos_token_id if (isinstance(self.generation_config.eos_token_id, list)) else [self.generation_config.eos_token_id]
+
+    def prefill(self, text:str):
+        input_ids = self.tokenizer.encode(text=text, return_tensors="pt").to(device=self.device)
+        return self._prefill(input_ids=input_ids)
+
+    def append(self, text:str):
+        pass
+
+    def get_ctx(self, input_ids: torch.LongTensor):
+        input_len = input_ids.size(1)
+        past_len = self.model.kv_cache.get_kv_len()
+        position_ids = torch.arange(past_len, past_len + input_len, device=self.device, dtype=torch.long).unsqueeze(0).repeat(input_ids.size(0), 1)
+        return position_ids
+
+    @torch.inference_mode()
+    def _prefill(self, input_ids:torch.LongTensor):
+
+        prefix_len = input_ids.shape[1]
+        if prefix_len >= self.max_length - 2 * self.safe_buffer:
+            return False
+
+        self.tokens[:,:prefix_len].copy_(input_ids)
+
+        logits = self.model.inference(input_ids=self.tokens[:,:prefix_len], position_ids=self.get_ctx(input_ids))
+        assert self.model.kv_cache.get_kv_len() == input_ids.shape[-1], f"KV length mismatch, got {self.model.kv_cache.get_kv_len()}, expected {input_ids.shape[-1]}"
+
+        next_token = self.sample_tokens(logits[:, -1, :])
+
+        self.tokens[:,prefix_len:prefix_len+1] = next_token
+
+        return True
+
+    @torch.inference_mode()
+    def _append(self, input_ids:torch.LongTensor):
+        pass
+
+    def update_generation_args(self, **generation_args):
+
+        self.temperature = generation_args.pop("temperature", self.temperature)
+        self.topp = generation_args.pop("topp", self.topp)
+        self.repetition_penalty = generation_args.pop("repetition_penalty", self.repetition_penalty)
+        self.topk = generation_args.pop("topk", self.topk)
+
+    @torch.inference_mode()
+    def reset(self):
+        self.tokens.zero_()
+        self.model.clear()
+
+    @torch.inference_mode()
+    def sample_tokens(self, logits):
+        # logits [bsz, seq, vocab]
+        if self.temperature < 0.05:
+            # greedy decoding
+            sampled_tokens = logits.argmax(dim=-1)
+        else:
+            # stochastic decoding
+            logits = apply_topk(logits, topk=self.topk)
+            proba = torch.softmax(logits/self.temperature, dim=-1)
+            proba = flashinfer.sampling.top_p_renorm_prob(proba, self.topp)
+            sampled_tokens = torch.multinomial(proba, num_samples=1).squeeze(-1)
+
+        return sampled_tokens
+
+    def validate_status(self):
+
+        return self.model.kv_cache.get_kv_len() < self.max_length - self.safe_buffer
+
+    @torch.inference_mode()
+    def generate(self, **api_args):
+
+        self.update_generation_args(**api_args)
+        input_ids = api_args.get("input_ids", None)
+        max_new_tokens = api_args.get("max_new_tokens", 128)
+
+        if input_ids is None:
+            context = api_args.get("context", None)
+            if context is None or len(context) == 0 or max_new_tokens == 0:
+                api_args["generated_text"] = ""
+                api_args["generated_tokens"] = []
+                api_args["avg_accept_tokens"] = 0
+                api_args["time_per_output_token"] = 0
+                return api_args
+            success = self.prefill(context)
+
+        else:
+            if len(input_ids) == 0 or max_new_tokens == 0:
+                api_args["generated_text"] = ""
+                api_args["generated_tokens"] = []
+                api_args["avg_accept_tokens"] = 0
+                api_args["time_per_output_token"] = 0
+                return api_args
+            input_ids = torch.Tensor(input_ids).long().unsqueeze(0).to(self.device)
+            success = self._prefill(input_ids=input_ids)
+
+        if not success:
+            api_args["generated_text"] = ""
+            api_args["generated_tokens"] = []
+            api_args["avg_accept_tokens"] = 0
+            api_args["time_per_output_token"] = 0
+            self.reset()
+            return api_args
+
+        start = self.model.kv_cache.get_kv_len()
+        next_token = self.tokens[:, self.model.kv_cache.get_kv_len()-1:self.model.kv_cache.get_kv_len()]
+        torch.cuda.synchronize()
+        t1 = time.time()
+
+        while (self.model.kv_cache.get_kv_len() - start) < max_new_tokens and self.validate_status():
+            logits = self.model.inference(input_ids=next_token, position_ids=self.get_ctx(next_token))
+            next_token = self.sample_tokens(logits[:, -1, :])
+            self.tokens[:,self.model.kv_cache.get_kv_len():self.model.kv_cache.get_kv_len()+1] = next_token
+            if next_token in self.eos_tokens:
+                break
+
+        torch.cuda.synchronize()
+        t2 = time.time()
+
+        dec_len = (self.model.kv_cache.get_kv_len() - start + 1)
+        generated_text = self.tokenizer.decode(
+            self.tokens[0,start:self.model.kv_cache.get_kv_len()+1].tolist(),
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )
+
+        api_args["generated_text"] = generated_text
+        api_args["generated_tokens"] = self.tokens[0,start:self.model.kv_cache.get_kv_len()+1].tolist()
+        api_args["time_per_output_token"] = 1000 * (t2-t1)/dec_len
+        self.reset()
+        return api_args
+
+    @torch.inference_mode()
+    def generate_stream(self, **api_args):
+        """
+        Streaming generation used by the Gradio chat app.
+        """
+        self.update_generation_args(**api_args)
+        input_ids = api_args.get("input_ids", None)
+        max_new_tokens = api_args.get("max_new_tokens", 128)
+
+        if input_ids is None:
+            context = api_args.get("context", None)
+            if context is None or len(context) == 0 or max_new_tokens == 0:
+                api_args["generated_text"] = ""
+                api_args["generated_tokens"] = []
+                api_args["avg_accept_tokens"] = 0
+                api_args["time_per_output_token"] = 0
+                return api_args
+            success = self.prefill(context)
+
+        else:
+            if len(input_ids) == 0 or max_new_tokens == 0:
+                api_args["generated_text"] = ""
+                api_args["generated_tokens"] = []
+                api_args["avg_accept_tokens"] = 0
+                api_args["time_per_output_token"] = 0
+                return api_args
+            input_ids = torch.Tensor(input_ids).long().unsqueeze(0).to(self.device)
+            success = self._prefill(input_ids=input_ids)
+
+        if not success:
+            yield "Exceeded the reserved context length", "Exceeded the reserved context length"
+            return
+
+        torch.cuda.synchronize()
+        t1 = time.time()
+        start = self.model.kv_cache.get_kv_len()
+        generated_ids = []
+        pos = 0
+
+        partial_text = ""
+
+        while self.validate_status():
+            begin_pos = self.model.kv_cache.get_kv_len()
+
+            logits = self.model.inference(input_ids=self.tokens[:, begin_pos - 1 : begin_pos], position_ids=self.get_ctx(self.tokens[:, begin_pos - 1 : begin_pos]))
+            next_token = self.sample_tokens(logits[:, -1, :])
+            self.tokens[:, begin_pos : begin_pos + 1] = next_token
+
+            new_ids = self.tokens[0, begin_pos : begin_pos + 1].tolist()
+            generated_ids.extend(new_ids)
+
+            generated_text_list = (
+                self.tokenizer.decode(
+                    generated_ids,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=False,
+                    spaces_between_special_tokens=False,
+                )
+                .strip()
+                .split(" ")
+            )
+
+            now = len(generated_text_list) - 1
+
+            if now > pos:
+                new_text_chunk = " ".join(generated_text_list[pos:now]) + " "
+                partial_text += new_text_chunk
+
+                t2 = time.time()
+                dec_len = (self.model.kv_cache.get_kv_len() - start + 1)
+
+                perf_log = "Output Tokens {} | TPOT {:.2f} ms ".format(
+                    dec_len, 1000 * (t2 - t1) / dec_len
+                )
+
+                yield partial_text, perf_log
+
+                pos = now
+
+            if (
+                is_sentence_complete_regex(generated_text_list[-1])
+                and (self.model.kv_cache.get_kv_len() - start >= max_new_tokens)
+            ) or ((self.model.kv_cache.get_kv_len() - start) >= max_new_tokens):
+                break
+
+        final_piece = " ".join(generated_text_list[pos:])
+        if final_piece:
+            partial_text += final_piece
+
+            t2 = time.time()
+            dec_len = (self.model.kv_cache.get_kv_len() - start + 1)
+
+            perf_log = "Output Tokens {} | TPOT {:.2f} ms ".format(
+                dec_len, 1000 * (t2 - t1) / dec_len
+            )
+            yield partial_text, perf_log
+
+        torch.cuda.synchronize()
+        t2 = time.time()
+        dec_len = (self.model.kv_cache.get_kv_len() - start + 1)
+        logger.info(
+            TextColors.colorize(
+                "Output Tokens {} | TPOT {:.2f} ms ".format(
+                    dec_len, 1000 * (t2 - t1) / dec_len
+                ),
+                "magenta",
+            )
+        )
+
+        self.reset()
diff --git a/umbrella/speculation/auto_engine.py b/umbrella/speculation/auto_engine.py
index 0030cb2..b352516 100644
--- a/umbrella/speculation/auto_engine.py
+++ b/umbrella/speculation/auto_engine.py
@@ -1,8 +1,10 @@
 from .dynamic_speculation_engine import DynamicSpeculationEngine
 from .static_speculation_engine import StaticSpeculationEngine
+from .ar_engine import AREngine
 
 class AutoEngine:
     _ENGINE_MAPPING = {
+        'ar': AREngine,
         'static': StaticSpeculationEngine,
         'dynamic': DynamicSpeculationEngine
     }
@@ -17,6 +19,9 @@ def from_config(cls, device: str, **kwargs):
         engine_class = cls._ENGINE_MAPPING[engine_name]
         draft_model_name = kwargs.pop("draft_model", None)
         target_model_name = kwargs.pop("model", None)
-        assert draft_model_name is not None
         assert target_model_name is not None
+
+        if draft_model_name is None:
+            return engine_class(model_name=target_model_name, device=device, **kwargs)
+
         return engine_class(draft_model_name=draft_model_name, target_model_name=target_model_name,device=device, **kwargs)
\ No newline at end of file
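A minimal usage sketch for the new single-model path (assumptions: `AutoEngine.from_config` resolves the engine class from the `engine` key, as `chat_config_ar.json` and the new `_ENGINE_MAPPING` entry suggest; the model name and sampling values below are illustrative, not prescribed by this patch):

    from umbrella.speculation.auto_engine import AutoEngine

    # No draft_model is passed, so from_config takes the new single-model branch
    # and constructs an AREngine with model_name set to the target model.
    engine = AutoEngine.from_config(
        device="cuda:0",
        engine="ar",
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        max_length=2048,
        temperature=0.6,
        topp=0.9,
    )
    engine.initialize()  # loads weights and tokenizer, allocates the KV cache

    result = engine.generate(context="Hello!", max_new_tokens=64)
    print(result["generated_text"], result["time_per_output_token"])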