diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 676ea67..fb59a17 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, @@ -170,9 +171,16 @@ def local_chat( torch.set_default_dtype( torch.bfloat16 ) # TODO: Remove this, replace dtype using config - generated = prefill_and_generate( - model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think - ) + + if system != "Windows" and config.architectures[0] in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM") and flashinfer_enabled: + generated = prefill_and_generate( + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, + use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim + ) + else: + generated = prefill_and_generate( + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, + ) if __name__ == "__main__": diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index 95d8086..434399f 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache): page_idx = cache_position // self.page_size page_offset = cache_position % self.page_size # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) - #print("page_idx", page_idx) - #print("page_offset", page_offset) k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states return k_out, self.page_table_list[layer_idx] diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index dc5902c..adc1c5f 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, generate_device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, orig_module.max_position_embeddings, orig_module.base ) @@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule): **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, generate_device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.generate_device = generate_device self.prefill_device = prefill_device @@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding): **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, generate_device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, @@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule,
DeepseekV2YarnRotaryEmbedding): **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, generate_device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, @@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): # **kwargs, # ): # BaseInjectedModule.__init__( -# self, key, gguf_loader, config, orig_module, generate_device, **kwargs +# self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs # ) # self.generate_device = generate_device # self.prefill_device = prefill_device @@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule): **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, generate_device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.generate_device = generate_device self.prefill_device = prefill_device @@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + prefill_device: str = "cuda", + generate_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( - self, key, gguf_loader, config, orig_module, device, **kwargs + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 53dac8b..cc57997 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache -from flash_attn import flash_attn_with_kvcache, flash_attn_func +from flash_attn import flash_attn_func from ktransformers.operators.triton_attention import decode_attention_fwd_grouped import os +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +if flashinfer_enabled: + from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref + logger = logging.getLogger("attention") # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + prefill_device: str = "cuda", + generate_device: str = "cuda", chunck_size: int = 1000, - use_triton: bool = False, **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. 
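# Sketch (not from the patch): the guarded flashinfer import added near the top of this file makes
# the dependency optional. A minimal, self-contained illustration of that pattern follows; the helper
# name pick_decode_backend is hypothetical and only mirrors the dispatch added to forward() further below.
try:
    import flashinfer  # noqa: F401 -- only availability is probed here
    flashinfer_enabled = True
except ImportError:
    flashinfer_enabled = False  # fall back to the Triton MQA decode kernel on Linux

def pick_decode_backend() -> str:
    # Prefer the flashinfer MLA decode path whenever the package is importable.
    return "forward_linux_flashinfer" if flashinfer_enabled else "forward_linux_triton"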
- self.use_triton = use_triton + self.mla_wrapper = None def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): @@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): #print(compressed_kv.shape) attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale + #attn_weights [bsz, self.num_heads, q_len, kv_seq_len] compressed_kv = compressed_kv.squeeze(1) """ @@ -168,8 +173,9 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) - + attn_output = torch.matmul(attn_output, out_absorb.mT) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): @@ -179,14 +185,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ) attn_output = attn_output.transpose(1, 2).contiguous() - + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value - def forward_linux( + def forward_linux_triton( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -232,7 +238,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] # decode - if self.use_triton and q_len == 1: + if q_len == 1: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) @@ -277,7 +283,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # use triton attention kernel adapted from vLLM and SGLang for MQA decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output, page_table, - position_ids.squeeze(0).to(torch.int32), attn_logits, + position_ids.squeeze(0).to(torch.int32)+1, attn_logits, 4, #num_kv_splits # follow vLLM, fix it TODO self.softmax_scale, past_key_value.page_size) @@ -337,6 +343,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value + + def forward_linux_flashinfer( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) + compressed_kv 
= compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank) + + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) + # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] + + # decode + if q_len == 1: + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank) + k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim) + # k_pe [max_pages, page_size, self.qk_rope_head_dim] + # compressed_kv [max_pages, page_size, self.kv_lora_rank] + + # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim] + # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank] + q_absorb, out_absorb = self.get_absorbed() + q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below + q_nope = torch.matmul(q_nope, q_absorb) # batched MM + q_nope = q_nope.transpose(1, 2) + assert q_nope.is_contiguous() + + # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] + # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] + q_nope.squeeze_(1) + q_pe.squeeze_(1) + + # flash attn doesn't support head_dim bigger than 256, use flashinfer + if self.mla_wrapper is None: + self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True) + if self.mla_wrapper.need_plan: + self.mla_wrapper.need_plan = False + self.mla_wrapper.plan(None,None,None, + position_ids.squeeze(1)+1, + self.num_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + past_key_value.page_size, + self.softmax_scale, + q_nope.dtype, + compressed_kv.dtype) + attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank) + + """ + k = ( + torch.cat([compressed_kv, k_pe], dim=-1) + .view(-1, 1, 512 + 64) + .repeat_interleave(self.num_heads, dim=1) + ) + v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1) + lens = position_ids.item() + 1 + #print("lens", lens) + attn_ref, lse_ref = attention_ref( + 1, + torch.cat([q_nope, q_pe], dim=-1), + k[:lens], + v[:lens], + False, + self.softmax_scale + ) + attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank) + """ + + # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank] + # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] + # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] + attn_output = attn_output.transpose(1, 2) + attn_output = torch.matmul(attn_output, out_absorb.mT) + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + else: + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + k_pe.squeeze(0) + compressed_kv.squeeze(0) + past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) + k_pe.unsqueeze(0) + compressed_kv.unsqueeze(0) + + k_pe = k_pe[:, :q_len] + compressed_kv = compressed_kv[:, :q_len] + kv = ( + self.kv_b_proj(compressed_kv) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + 
self.v_head_dim) + ) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) + key_states[:, :, :, :self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + + value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim) + value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states_padded, + softmax_scale=self.softmax_scale, + causal=True, + ) + + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value def forward_windows( self, @@ -415,7 +569,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode + if os.name == 'nt': return self.forward_windows( hidden_states, attention_mask, @@ -427,16 +581,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): **kwargs, ) else: - return self.forward_linux( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - **kwargs, - ) + if flashinfer_enabled: + return self.forward_linux_flashinfer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + else: + return self.forward_linux_triton( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) class KLlamaAttention(BaseInjectedModule): @@ -447,9 +613,10 @@ class KLlamaAttention(BaseInjectedModule): gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + prefill_device: str = "cuda", + generate_device: str = "cuda", **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): diff --git a/ktransformers/operators/base_operator.py b/ktransformers/operators/base_operator.py index 1cf1471..0fa2efd 100644 --- a/ktransformers/operators/base_operator.py +++ b/ktransformers/operators/base_operator.py @@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module): gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + prefill_device: str = "cuda", + generate_device: str = "cuda", **kwargs): nn.Module.__init__(self) nn.Module.__setattr__(self, "orig_module", orig_module) object.__setattr__(self, "key", key) object.__setattr__(self, "gguf_loader", gguf_loader) object.__setattr__(self, "config", config) - object.__setattr__(self, "device", device) + 
object.__setattr__(self, "prefill_device", prefill_device) + object.__setattr__(self, "generate_device", generate_device) + object.__setattr__(self, "device", generate_device) def __getattr__(self, name: str) -> Any: # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__, diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 32675dc..21b4830 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase): output_cpu:Tensor = None output_gpu_map:dict = {} # Manage output tensor buffer on different gpu #stream_map:dict = {} # Manage cuda stream on different gpu + #gguf_loader:GGUFLoader = None CPU_INFER = CPUInfer(Config().cpu_infer) def __init__( self, @@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase): **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + #if KExpertsCPU.gguf_loader is None: + # KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") + self.gguf_loader = gguf_loader assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU" self.n_routed_experts = n_routed_experts self.out_device = out_device @@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py new file mode 100644 index 0000000..8d49187 --- /dev/null +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -0,0 +1,240 @@ +''' +Description : flashinfer MLA wrapper +Author : Boxin Zhang +Version : 0.2.2 +''' +import torch + +flashinfer_enabled = False + +try: + import flashinfer + flashinfer_enabled = True + print("found flashinfer") + +except ImportError: + print("flashinfer not found, use triton for linux") + +import math + +def attention_ref( + batch_size, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sm_scale: float, +) -> torch.Tensor: + qo_len = q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + ) + * sm_scale + ) + + #print("attn weights", logits) + + if causal: + mask = ( + torch.arange(kv_len - qo_len, kv_len).unsqueeze(1) + >= torch.arange(0, kv_len).unsqueeze(0) + ).to(q.device) + else: + mask = torch.ones(qo_len, kv_len).to(q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_qo_heads, head_dim_vo) + .to(q) + ) + + 
return o_ref, lse_ref * math.log2(math.e) + +class MLAWrapper(): + def __init__(self, + max_batch_size, + max_pages, + use_cuda_graph = True, + device = "cuda", + ): + self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device) + self.max_batch_size = max_batch_size + self.max_pages = max_pages + if use_cuda_graph: + if self.max_batch_size == 1: + self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device) + self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device) + self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device) + else: + self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) + self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) + self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device) + self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device) + else: + self.qo_indptr_buf = None + self.kv_indptr_buf = None + self.kv_indices_buf = None + self.kv_len_arr_buf = None + self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.float_workspace_buffer, + use_cuda_graph=False, + qo_indptr=self.qo_indptr_buf, + kv_indptr=self.kv_indptr_buf, + kv_indices=self.kv_indices_buf, + kv_len_arr=self.kv_len_arr_buf, + ) + self.need_plan = True + + def plan(self, + qo_indptr, + kv_indptr, + kv_indices, + kv_len_arr, + num_heads, + head_dim_ckv, + head_dim_kpe, + page_size, + sm_scale, + q_data_type, + kv_data_type, + ): + if qo_indptr is None: + assert self.max_batch_size == 1 + qo_indptr = self.qo_indptr_buf + if kv_indptr is None: + assert self.max_batch_size == 1 + kv_indptr = self.kv_indptr_buf + if kv_indices is None: + assert self.max_batch_size == 1 + kv_indices = self.kv_indices_buf + + self.wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_len_arr, + num_heads, + head_dim_ckv, + head_dim_kpe, + page_size, + False, # causal is False for decoding + sm_scale, + q_data_type, + kv_data_type, + ) + + def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False): + return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse) + +class MLAWrapperSingleton(): + wrappers:dict = {} + + @classmethod + def get_instance(cls, device, *args, **kwargs)->MLAWrapper: + if device not in cls.wrappers: + cls.make_instance(device, *args, **kwargs) + return cls.wrappers[device] + + @classmethod + def make_instance(cls, device, *args, **kwargs): + cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device) + + @classmethod + def plan_all(cls, qo_indptr, + kv_indptr, + kv_indices, + kv_len_arr, + num_heads, + head_dim_ckv, + head_dim_kpe, + page_size, + sm_scale, + q_data_type, + kv_data_type,): + for device, wrapper in cls.wrappers.items(): + kv_len_arr_cur_device = kv_len_arr.to(device) + wrapper.plan(qo_indptr, + kv_indptr, + kv_indices, + kv_len_arr_cur_device, + num_heads, + head_dim_ckv, + head_dim_kpe, + page_size, + sm_scale, + q_data_type, + kv_data_type,) + + +if __name__ == "__main__": + max_batch_size = 1 + max_pages = 1 + page_size = 64 + num_heads = 128 + + q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda") + q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda") + ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda") + k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda") + + + wrapper = MLAWrapperSingleton.get_instance( + "cuda", + max_batch_size, + 
max_pages, + ) + + kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda") + + wrapper.plan( + None, + None, + None, + kv_len_arr, + 128, + 512, + 64, + page_size, + 192 ** (-0.5), + torch.bfloat16, + torch.bfloat16, + ) + + attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe) + + k = ( + torch.cat([ckv, k_pe], dim=-1) + .view(-1, 1, 512 + 64) + .repeat_interleave(num_heads, dim=1) + ) + v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) + + print(k[:10].shape) + print(v[:10].shape) + + attn_ref, lse_ref = attention_ref( + max_batch_size, + torch.cat([q_nope, q_pe], dim=-1), + k[:10], + v[:10], + False, + 192 ** (-0.5) + ) + + torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3) + print("test past") \ No newline at end of file diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index ab7d0b2..52bb33a 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, - generate_device: str = "cuda", prefill_device: str = "cuda", + generate_device: str = "cuda", **kwargs, ): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.generate_device = generate_device self.prefill_device = prefill_device diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index df01ac9..08a2cca 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): prefill_op: str| None = "KLinearTorch", **kwargs, ): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) # build all the linear operators if prefill_op is not None: diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 62059a7..eaa1a7d 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -109,6 +109,7 @@ GGML_TYPES = { "Q5_K": 13, "Q6_K": 14, "IQ4_XS": 23, + "BF16": 30, } GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()} @@ -116,6 +117,7 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()} GGML_BLOCK_SIZES = { "F32": 4, "F16": 2, + "BF16": 2, "Q4_0": 2 + 16, "Q5_0": 2 + 4 + 16, "Q8_0": 2 + 32, @@ -130,6 +132,7 @@ GGML_BLOCK_SIZES = { GGML_ELEMENTS_PER_BLOCK = { "F32": 1, "F16": 1, + "BF16": 1, "Q4_0": 32, "Q5_0": 32, "Q8_0": 32, @@ -333,6 +336,8 @@ class GGUFLoader: else: values = GGML_DEQUANTIZE[ggml_name](data) values = torch.from_numpy(values) + if ggml_name == "BF16": + values = values.view(torch.bfloat16) values = values.view(shape[::-1]) if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: n_head = self.gguf_file_meta['llama.attention.head_count'] @@ -764,6 +769,7 @@ def dequantize_f16_gpu(data, device): GGML_DEQUANTIZE = { "F32": dequantize_f32, "F16": dequantize_f16, + "BF16": dequantize_f16, "Q4_0": dequantize_q4_0, "Q5_0": 
dequantize_q5_0, "Q8_0": dequantize_q8_0, @@ -778,6 +784,7 @@ GGML_DEQUANTIZE = { GGML_DEQUANTIZE_GPU = { "F32": dequantize_f32_gpu, "F16": dequantize_f16_gpu, + "BF16": dequantize_f16_gpu, "Q4_0": dequantize_q4_0_gpu, "Q5_0": dequantize_q5_0_gpu, "Q8_0": dequantize_q8_0_gpu, diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 5db643a..7034ac9 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -17,6 +17,7 @@ from ktransformers.operators import base_operator from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.util.textstream import TextStreamer +from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton warm_uped = False @@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): module.load() def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, - mode = 'normal', force_think: bool = False): + mode = 'normal', force_think: bool = False, use_flashinfer_mla = False, + num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None): import os os.environ["TOKENIZERS_PARALLELISM"] = "false" torch._dynamo.config.suppress_errors = True @@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud ) else: past_key_values = None - cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long) + cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32) generated_ids = torch.zeros( batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device ) @@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud generated_ids[:, seq_length] = next_token tokens.append(int(next_token)) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) - cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long) + cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32) position_ids = cache_position.unsqueeze(0) seq_length += 1 @@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud warm_uped = True cuda_graph_runner = CUDAGraphRunner() cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True) - + if i > 1 and use_flashinfer_mla: + MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1, + num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, + q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16) next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int()
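
A minimal sketch of how the numbers passed to plan_all above fit together (not part of the patch; the concrete values are assumptions matching the flashinfer_wrapper __main__ self-test and typical DeepSeek-V2/V3 configs):

    # Assumed DeepSeek-V2/V3 MLA dimensions, for illustration only.
    num_heads = 128                                 # config.num_attention_heads
    head_dim_ckv = 512                              # config.kv_lora_rank (compressed KV width per token)
    head_dim_kpe = 64                               # config.qk_rope_head_dim (rotary slice of each head)
    qk_nope_head_dim = 128                          # non-rotary slice of each query head
    q_head_dim = head_dim_kpe + qk_nope_head_dim    # 192, as computed in local_chat.py
    sm_scale = q_head_dim ** (-0.5)                 # ~0.0722, the softmax scale handed to plan_all()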