diff --git a/examples/generate.py b/examples/generate.py index f130030..49038fa 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -6,7 +6,7 @@ logger = setup_logger() import torch from umbrella.templates import Prompts, SysPrompts -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModelForCausalLM, MistralForCausalLM from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, find_first_element_position import argparse import time @@ -30,6 +30,7 @@ text = system_prompt + text tokenizer = AutoTokenizer.from_pretrained(args.model) + tokens = tokenizer.encode(text=text, return_tensors="pt").to(DEVICE) llm = AutoModelLM.from_pretrained( diff --git a/examples/generate_directly.py b/examples/generate_directly.py new file mode 100644 index 0000000..fb74c8d --- /dev/null +++ b/examples/generate_directly.py @@ -0,0 +1,44 @@ +# Load model directly +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, MistralForCausalLM +from umbrella.speculation.speculation_utils import make_causal_mask, is_sentence_complete_regex, find_first_element_position + +DEVICE = "cuda:0" +MAX_LEN = 2048 + +attention_mask = make_causal_mask((MAX_LEN, MAX_LEN), DEVICE) + +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3") + +model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.3", torch_dtype=torch.float16, _attn_implementation="eager").to(DEVICE) + +# # tokenizer.add_special_tokens({'pad_token_id': '[PAD]'}) +# tokenizer.padding_side = 'right' +# tokenizer.add_eos_token = True +# tokenizer.pad_token_id=2041 +# eos_token_id=tokenizer.eos_token_id +# model.resize_token_embeddings(len(tokenizer)) +# model.config.pad_token_id = tokenizer.pad_token_id + +# model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.3", torch_dtype=torch.float16, '''_attn_implementation="eager"''' max_length=MAX_LEN, attention_mask=attention_mask).to("cuda:0") +# text = "Tell me what you know about Reinforcement Learning in 100 words." 
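+# NOTE: mistralai/Mistral-7B-v0.3 is the base (non-instruct) checkpoint and may not ship
+# a chat template, so the [INST] ... [/INST] wrapping below follows the Mistral-Instruct
+# convention and is applied by hand; tokenizer.apply_chat_template (tried in the
+# commented-out block further down) is the alternative. Also note that the
+# `attention_mask` built above is not passed to model.generate and is currently unused.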
+text = "[INST] Tell me what you know about Reinforcement Learning in 100 words.[/INST]" + +# messages = [{"role": "user", "content": text}] + +# # Modified template application +# prompt = tokenizer.apply_chat_template( +# messages, +# tokenize=False, +# add_generation_prompt=True # Critical for response triggering +# ) + + +input_ids = tokenizer.encode(text=text, return_tensors="pt").to(DEVICE) + +# input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +# prefix_len = input_ids.shape[1] + +output = model.generate(input_ids, do_sample=False, max_new_tokens=512) +print(tokenizer.decode(output[0], skip_special_tokens=True)) diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 5c60c1f..eb51c35 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -1,6 +1,6 @@ from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph -from .gemma import Gemma2 +from .mistral import Mistral, MistralAwq, MistralOffload, MistralAwqOffload, MistralCudagraph class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -17,6 +17,9 @@ class AutoModelLM: "meta-llama/Llama-3.1-8B-Instruct": LlamaOffload, "meta-llama/Meta-Llama-3-70B-Instruct": LlamaOffload, "meta-llama/Meta-Llama-3-8B-Instruct": LlamaOffload, + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B":LlamaOffload, + "deepseek-ai/DeepSeek-R1-Distill-Llama-70B":LlamaOffload, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":QwenOffload, "Qwen/Qwen2.5-Coder-72B-Instruct": QwenOffload, "Qwen/Qwen2.5-Coder-32B-Instruct": QwenOffload, "Qwen/Qwen2.5-Coder-14B-Instruct": QwenOffload, @@ -47,8 +50,8 @@ class AutoModelLM: "Qwen/Qwen2.5-32B-Instruct-AWQ": QwenAwqOffload, "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwqOffload, "KirillR/QwQ-32B-Preview-AWQ": QwenAwqOffload, - "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload - + "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwqOffload, + "mistralai/Mistral-7B-v0.3": MistralOffload, # Mistral 7B added by EJ } _MODEL_MAPPING = { @@ -73,6 +76,12 @@ class AutoModelLM: "Zhuominc/Coder-400M-IT": Llama, "Zhuominc/FastCode-500M": Llama, "InfiniAILab/CodeDrafter-500M": Llama, + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B":Llama, + "deepseek-ai/DeepSeek-R1-Distill-Llama-70B":Llama, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":Qwen, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B":Qwen, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B":Qwen, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B":Qwen, "Qwen/Qwen2.5-Coder-72B-Instruct": Qwen, "Qwen/Qwen2.5-Coder-32B-Instruct": Qwen, "Qwen/Qwen2.5-Coder-14B-Instruct": Qwen, @@ -104,8 +113,7 @@ class AutoModelLM: "Qwen/Qwen2.5-72B-Instruct-AWQ": QwenAwq, "KirillR/QwQ-32B-Preview-AWQ": QwenAwq, "casperhansen/deepseek-r1-distill-qwen-32b-awq":QwenAwq, - "google/gemma-2-2b-it": Gemma2, - "google/gemma-2-2b": Gemma2 + "mistralai/Mistral-7B-v0.3": Mistral, # Mistral 7B added by EJ } _CUDAGRAPH_MODEL_MAPPING = { @@ -122,6 +130,7 @@ class AutoModelLM: "Zhuominc/Coder-400M-IT": LlamaCudagraph, "Zhuominc/FastCode-500M": LlamaCudagraph, "InfiniAILab/CodeDrafter-500M": LlamaCudagraph, + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B":QwenCudagraph, "Qwen/Qwen2.5-Coder-72B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-Coder-32B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-Coder-14B-Instruct": QwenCudagraph, @@ -136,7 +145,8 @@ class AutoModelLM: "Qwen/Qwen2.5-14B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-32B-Instruct": QwenCudagraph, "Qwen/Qwen2.5-72B-Instruct": QwenCudagraph, - 
"Qwen/QwQ-32B-Preview": QwenCudagraph + "Qwen/QwQ-32B-Preview": QwenCudagraph, + "mistralai/Mistral-7B-v0.3": MistralCudagraph, # Mistral 7B added by EJ } @classmethod @@ -165,4 +175,4 @@ def from_pretrained(cls, model_name, offload=False, cuda_graph=False, **kwargs): raise ValueError(f"Model type '{model_name}' is not supported (offload). " f"Supported (offload) types: {list(cls._OFFLOAD_MODEL_MAPPING.keys())}") model_class = cls._OFFLOAD_MODEL_MAPPING[model_name] - return model_class(model_name = model_name, **kwargs) + return model_class(model_name = model_name, **kwargs) \ No newline at end of file diff --git a/umbrella/models/mistral.py b/umbrella/models/mistral.py new file mode 100644 index 0000000..ed9e80e --- /dev/null +++ b/umbrella/models/mistral.py @@ -0,0 +1,531 @@ +from transformers import MistralModel, MistralForCausalLM, MistralConfig, AutoModelForCausalLM, AutoTokenizer +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .mistral_layer import MistralLayer, MistralAwqLayer, MistralPackedLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +class Mistral(LLMBase): + def __init__(self, + model_name: str, + batch_size: int=1, + max_length: int=256, + device: str = "cuda:0", + dtype = torch.float16) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = MistralConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + def alloc(self, **kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = 
self.sin_cache.to(self.dtype) + + self.layers :list[MistralLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() + + +class MistralOffload(Mistral): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + self.load_stream = torch.cuda.Stream(device=device) + + def alloc(self, **kwargs): + + + self.num_cache_layers = kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = 
MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[MistralLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = MistralLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [MistralLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + if self.buffer[0].layer_idx != 0: + self.buffer[0].copy(self.layers[0]) + torch.cuda.synchronize() + for idx in range(self.num_layers): + with torch.cuda.stream(self.load_stream): + self.buffer[(idx + 1) % 2].copy(self.layers[(idx + 1)% self.num_layers]) + + hidden_states = self.layer_compute(self.buffer[idx % 2], idx, hidden_states, position_ids, attention_mask, storage_ids) + torch.cuda.synchronize() + b, s, h = hidden_states.shape + + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + +class MistralAwq(Mistral): + def alloc(self, **kwargs): + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + 
self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + self.layers :list[MistralAwqLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = MistralAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralAwqLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) + + for idx in range(self.num_layers): + hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids) + + b, s, h = hidden_states.shape + hidden_states = hidden_states.reshape(b * s, h) + hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon) + hidden_states = hidden_states.reshape(b, s, h) + logits = F.linear(hidden_states, self.lm_head).float() + return logits + + +class MistralAwqOffload(MistralOffload): + def alloc(self, **kwargs): + + self.num_cache_layers = 
kwargs["num_cache_layers"] if 'num_cache_layers' in kwargs else 0 + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + self.layers :list[MistralAwqLayer] = [] + + for idx, hf_layer in tqdm(enumerate(hf_model.model.layers), desc="initial offloaded model"): + layer = MistralAwqLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + if idx < self.num_cache_layers: + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + self.num_layers = len(self.layers) + assert self.num_layers % 2 == 0 + self.buffer = [MistralAwqLayer(-1, self.device) for _ in range(2)] + self.buffer[0].alloc_space(self.layers[0], self.device) + self.buffer[1].alloc_space(self.layers[0], self.device) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralAwqLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + query_states = buffer.wq.apply(hidden_states) + key_states = buffer.wk.apply(hidden_states) + value_states = buffer.wv.apply(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids) + hidden_states = self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + + hidden_states = buffer.wo.apply(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = buffer.up_proj.apply(hidden_states) + 
gate = buffer.gate_proj.apply(hidden_states) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = buffer.down_proj.apply(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MistralCudagraph(Mistral): + def __init__(self, model_name, batch_size = 1, max_length = 256, device = 'cuda:0', dtype=torch.float16): + super().__init__(model_name, batch_size, max_length, device, dtype) + + self.callables = {} + self.mempool = None + + def alloc(self, **kwargs): + + exit_layer = kwargs.pop("exit_layer", -1) + self.kv_cache = StaticKV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + hf_model = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + if self.config.tie_word_embeddings: + self.lm_head = self.embed_tokens + else: + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device) + self.attention_scaling = hf_model.model.rotary_emb.attention_scaling + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + self.cos_cache = self.cos_cache * self.attention_scaling + self.sin_cache = self.sin_cache * self.attention_scaling + self.cos_cache = self.cos_cache.to(self.dtype) + self.sin_cache = self.sin_cache.to(self.dtype) + + self.layers :list[MistralPackedLayer] = [] + + for idx, hf_layer in enumerate(hf_model.model.layers): + if exit_layer > 0 and idx >= exit_layer: + break + layer = MistralPackedLayer(idx) + layer.init_parameters(hf_layer=hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + @torch.inference_mode() + def layer_compute(self, + buffer: MistralPackedLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight) + bsz, q_len, _ = hidden_states.size() + qkv = F.linear(hidden_states, buffer.wqkv) + query_states = qkv[...,:self.hidden_size] + key_states = qkv[...,self.hidden_size:self.hidden_size + self.head_dim * self.num_key_value_heads] + value_states = qkv[...,self.hidden_size + self.head_dim * self.num_key_value_heads:] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2) + + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids, unsqueeze_dim=1) + hidden_states = 
self.kv_cache.compute_attention( + query_states, key_states, value_states, layer_idx, storage_ids, attention_mask + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight) + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + hidden_states = residual + hidden_states + + return hidden_states + + + @torch.inference_mode() + def initialize_cuda_graph(self, + decoding_seqlens :list[int], + n_warmups=12): + gc.collect() + self.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if decoding_seqlen not in self.callables: + self.callables[decoding_seqlen] = capture_graph( + llm=self, + decoding_seqlen=decoding_seqlen, + mempool=self.mempool, + n_warmups=n_warmups + ) + self.clear() + + @torch.inference_mode() + def graph_inference(self, + input_ids: torch.LongTensor, + storage_ids :torch.LongTensor, + position_ids = None, + attention_mask = None, + ): + dec_length = input_ids.shape[1] + if dec_length in self.callables.keys(): + logits = self.callables[dec_length](input_ids, storage_ids, position_ids, attention_mask) + else: + logits = self.inference(input_ids, position_ids, attention_mask, storage_ids) + return logits + diff --git a/umbrella/models/mistral_layer.py b/umbrella/models/mistral_layer.py new file mode 100644 index 0000000..be05520 --- /dev/null +++ b/umbrella/models/mistral_layer.py @@ -0,0 +1,258 @@ +from __future__ import annotations +import torch +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer +from ..quantization.awq_utils import AwqLinear + +class MistralLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: MistralDecoderLayer): + + self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = 
self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: MistralLayer): + + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: MistralLayer, device): + + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + +class MistralPackedLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wqkv :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: MistralDecoderLayer): + + self.wqkv :torch.Tensor= torch.cat( + [ + hf_layer.self_attn.q_proj.weight.detach(), + hf_layer.self_attn.k_proj.weight.detach(), + hf_layer.self_attn.v_proj.weight.detach(), + ], + dim=0 + ) + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device 
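+        # Move every tensor of the packed layer to the target device. non_blocking=True
+        # only overlaps the copy with compute when the source host tensors are pinned;
+        # the weights detached in init_parameters above are not pinned, so these
+        # transfers may effectively be synchronous.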
+ self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wqkv = self.wqkv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: MistralPackedLayer): + + self.wqkv.copy_(layer.wqkv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx + + def alloc_space(self, layer: MistralPackedLayer, device): + + self.device = device + self.wqkv = torch.zeros_like(layer.wqkv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + +class MistralAwqLayer(): + def __init__(self, layer_idx, device="cpu"): + + self.wq = AwqLinear() + self.wk = AwqLinear() + self.wv = AwqLinear() + self.wo = AwqLinear() + + self.gate_proj = AwqLinear() + self.up_proj = AwqLinear() + self.down_proj = AwqLinear() + + + self.input_layernorm_weight :torch.Tensor = None + self.input_layernorm_variance_epsilon :float = 0.0 + + self.post_attention_layernorm_weight :torch.Tensor = None + self.post_attention_layernorm_variance_epsilon :float = 0.0 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer): + + self.wq.init_parameters(hf_layer.self_attn.q_proj) + self.wk.init_parameters(hf_layer.self_attn.k_proj) + self.wv.init_parameters(hf_layer.self_attn.v_proj) + self.wo.init_parameters(hf_layer.self_attn.o_proj) + self.gate_proj.init_parameters(hf_layer.mlp.gate_proj) + self.up_proj.init_parameters(hf_layer.mlp.up_proj) + self.down_proj.init_parameters(hf_layer.mlp.down_proj) + + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach().pin_memory() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach().pin_memory() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + + self.wq.to(device=device) + 
self.wk.to(device=device) + self.wv.to(device=device) + self.wo.to(device=device) + + self.gate_proj.to(device=device) + self.up_proj.to(device=device) + self.down_proj.to(device=device) + + def alloc_space(self, layer: MistralAwqLayer, device): + + self.device = device + self.wq.empty_like(layer.wq) + self.wk.empty_like(layer.wk) + self.wv.empty_like(layer.wv) + self.wo.empty_like(layer.wo) + + self.gate_proj.empty_like(layer.gate_proj) + self.up_proj.empty_like(layer.up_proj) + self.down_proj.empty_like(layer.down_proj) + + self.wq.to(device=device) + self.wk.to(device=device) + self.wv.to(device=device) + self.wo.to(device=device) + + self.gate_proj.to(device=device) + self.up_proj.to(device=device) + self.down_proj.to(device=device) + + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) + + def copy(self, layer: MistralAwqLayer): + + self.wq.copy(layer.wq, non_blocking=True) + self.wk.copy(layer.wk, non_blocking=True) + self.wv.copy(layer.wv, non_blocking=True) + self.wo.copy(layer.wo, non_blocking=True) + self.gate_proj.copy(layer.gate_proj, non_blocking=True) + self.up_proj.copy(layer.up_proj, non_blocking=True) + self.down_proj.copy(layer.down_proj, non_blocking=True) + + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + self.input_layernorm_variance_epsilon= layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + self.layer_idx = layer.layer_idx \ No newline at end of file diff --git a/umbrella/templates.py b/umbrella/templates.py index c71dc5f..cafcca8 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -6,6 +6,19 @@ <|start_header_id|>assistant<|end_header_id|> """, + +# <|begin_of_text|> +# <|start_header_id|>system<|end_header_id|> + +# {{ system_prompt }}<|eot_id|> + +# <|start_header_id|>user<|end_header_id|> +# {{ user_message_1 }}<|eot_id|> +# <|start_header_id|>assistant<|end_header_id|> + +# {{ model_response_1 }}<|eot_id|> + + 'llama3-code':"""<|start_header_id|>user<|end_header_id|> {}<|eot_id|><|start_header_id|>assistant<|end_header_id|> @@ -14,14 +27,40 @@ 'qwen': """<|im_start|>user {}<|im_end|> <|im_start|>assistant -""", +""", + # 'mistral-7b': """[INST] {}[/INST] """, + # 'mistral-7b': """[INST]{}[/INST]""", + # 'mistral-7b': """[INST]{}[/INST]""", + # 'mistral-7b': """[INST] {}[/INST] """, + # 'mistral-7b': """ + # 'mistral-7b': """[INST] {} [/INST]""", +# 'mistral-7b': """[INST] {} [/INST] +# """ + # 'mistral-7b': """[INST] {} [/INST] .""", + 'mistral-7b': """[INST] {}[/INST]""", + # 'mistral-7b': """[INST] {}[/INST]""", + # 'mistral-7b': """[INST] {} [/INST] """ -'gemma2-it': """user -{} -model -""", - -'gemma2': "{}" + # 'mistral-7b': """[INST] {}[/INST]""" ## current best + # 'mistral-7b': """[INST] {} [/INST] """ + # 'mistral-7b': """{}""" + # 'mistral-7b': """[INST] {} [/INST] """ + # 'mistral-7b': """[INST] {} [/INST]""" + # 'mistral-7b': """[INST] {} """ + # 'mistral-7b': """[INST] {}""" + # 'mistral-7b': """[INST]{}""" + # 'mistral-7b': """[INST] {}[/INST]""" + # 'mistral-7b': """[INST] {} [/INST]\n""" + # 'mistral-7b': """[INST] {} [/INST]""" + + # # {} [/INST] """, + # 'mistral-7b': """{}[/INST]""" + + # 'mistral-7b': """[INST]user message[/INST]assistant message[INST]new user 
message[/INST]""", + + # 'mistral-7b': """[INST] {}[/INST] assistant message""", + + } SysPrompts = { @@ -34,11 +73,22 @@ 'qwen': """<|im_start|>system You are a helpful assistant.<|im_end|> """, - 'gemma2': "", - 'gemma2-it': "" +# 'mistral-7b': """[INST] system prompt + +# new user message[/INST]""", + # 'mistral-7b': """[INST] [/INST]""", + # 'mistral-7b': """[INST]""", + 'mistral-7b': """""" + # 'mistral-7b': """""" + # 'mistral-7b': """\n""" + +# 'mistral-7b': """[INST] user message[/INST] assistant message[INST] system prompt + +# new user message[/INST]""", } ExtraPrompts = { 'llama3-code': """\nAlways try to wrap what you write in a function.""" } +