mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
Merge branch 'fix_precision_MLA' of https://github.com/kvcache-ai/ktransformers into server-prefix-cache
This commit is contained in:
commit
bb1cadfff3
11 changed files with 479 additions and 46 deletions
|
@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||||
from ktransformers.util.utils import prefill_and_generate
|
from ktransformers.util.utils import prefill_and_generate
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
|
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||||
|
|
||||||
custom_models = {
|
custom_models = {
|
||||||
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
||||||
|
@ -170,9 +171,16 @@ def local_chat(
|
||||||
torch.set_default_dtype(
|
torch.set_default_dtype(
|
||||||
torch.bfloat16
|
torch.bfloat16
|
||||||
) # TODO: Remove this, replace dtype using config
|
) # TODO: Remove this, replace dtype using config
|
||||||
generated = prefill_and_generate(
|
|
||||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
|
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
|
||||||
)
|
generated = prefill_and_generate(
|
||||||
|
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
|
||||||
|
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
generated = prefill_and_generate(
|
||||||
|
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
|
||||||
page_idx = cache_position // self.page_size
|
page_idx = cache_position // self.page_size
|
||||||
page_offset = cache_position % self.page_size
|
page_offset = cache_position % self.page_size
|
||||||
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
|
||||||
#print("page_idx", page_idx)
|
|
||||||
#print("page_offset", page_offset)
|
|
||||||
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
|
||||||
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
|
||||||
return k_out, self.page_table_list[layer_idx]
|
return k_out, self.page_table_list[layer_idx]
|
||||||
|
|
|
@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.orig_module.__init__(
|
self.orig_module.__init__(
|
||||||
orig_module.dim, orig_module.max_position_embeddings, orig_module.base
|
orig_module.dim, orig_module.max_position_embeddings, orig_module.base
|
||||||
|
@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.generate_device = generate_device
|
self.generate_device = generate_device
|
||||||
self.prefill_device = prefill_device
|
self.prefill_device = prefill_device
|
||||||
|
@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.orig_module.__init__(
|
self.orig_module.__init__(
|
||||||
orig_module.dim,
|
orig_module.dim,
|
||||||
|
@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.orig_module.__init__(
|
self.orig_module.__init__(
|
||||||
orig_module.dim,
|
orig_module.dim,
|
||||||
|
@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
|
||||||
# **kwargs,
|
# **kwargs,
|
||||||
# ):
|
# ):
|
||||||
# BaseInjectedModule.__init__(
|
# BaseInjectedModule.__init__(
|
||||||
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
# self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
# )
|
# )
|
||||||
# self.generate_device = generate_device
|
# self.generate_device = generate_device
|
||||||
# self.prefill_device = prefill_device
|
# self.prefill_device = prefill_device
|
||||||
|
@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.generate_device = generate_device
|
self.generate_device = generate_device
|
||||||
self.prefill_device = prefill_device
|
self.prefill_device = prefill_device
|
||||||
|
@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
|
||||||
gguf_loader: GGUFLoader,
|
gguf_loader: GGUFLoader,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
orig_module: nn.Module,
|
orig_module: nn.Module,
|
||||||
device: str = "cuda",
|
prefill_device: str = "cuda",
|
||||||
|
generate_device: str = "cuda",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(
|
BaseInjectedModule.__init__(
|
||||||
self, key, gguf_loader, config, orig_module, device, **kwargs
|
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||||
)
|
)
|
||||||
self.orig_module.__init__(
|
self.orig_module.__init__(
|
||||||
orig_module.dim,
|
orig_module.dim,
|
||||||
|
|
|
@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
|
||||||
import logging
|
import logging
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from flash_attn import flash_attn_with_kvcache, flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||||
import os
|
import os
|
||||||
|
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||||
|
if flashinfer_enabled:
|
||||||
|
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
|
||||||
|
|
||||||
logger = logging.getLogger("attention")
|
logger = logging.getLogger("attention")
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||||
|
@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
gguf_loader : GGUFLoader,
|
gguf_loader : GGUFLoader,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
orig_module: nn.Module,
|
orig_module: nn.Module,
|
||||||
device: str = "cuda",
|
prefill_device: str = "cuda",
|
||||||
|
generate_device: str = "cuda",
|
||||||
chunck_size: int = 1000,
|
chunck_size: int = 1000,
|
||||||
use_triton: bool = False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||||
self.orig_module.__init__(orig_module.config,
|
self.orig_module.__init__(orig_module.config,
|
||||||
orig_module.layer_idx)
|
orig_module.layer_idx)
|
||||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||||
self.use_triton = use_triton
|
self.mla_wrapper = None
|
||||||
|
|
||||||
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
||||||
|
@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
#print(compressed_kv.shape)
|
#print(compressed_kv.shape)
|
||||||
|
|
||||||
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
|
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
|
||||||
|
|
||||||
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
|
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
|
||||||
compressed_kv = compressed_kv.squeeze(1)
|
compressed_kv = compressed_kv.squeeze(1)
|
||||||
"""
|
"""
|
||||||
|
@ -168,6 +173,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
attn_weights = nn.functional.dropout(
|
attn_weights = nn.functional.dropout(
|
||||||
attn_weights, p=self.attention_dropout, training=self.training
|
attn_weights, p=self.attention_dropout, training=self.training
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
|
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
|
||||||
|
|
||||||
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||||
|
@ -186,7 +192,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
def forward_linux(
|
def forward_linux_triton(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
@ -232,7 +238,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||||
|
|
||||||
# decode
|
# decode
|
||||||
if self.use_triton and q_len == 1:
|
if q_len == 1:
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||||
|
@ -277,7 +283,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
# use triton attention kernel adapted from vLLM and SGLang for MQA
|
# use triton attention kernel adapted from vLLM and SGLang for MQA
|
||||||
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
|
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
|
||||||
page_table,
|
page_table,
|
||||||
position_ids.squeeze(0).to(torch.int32), attn_logits,
|
position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
|
||||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
past_key_value.page_size)
|
past_key_value.page_size)
|
||||||
|
@ -338,6 +344,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
def forward_linux_flashinfer(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.q_lora_rank is None:
|
||||||
|
q = self.q_proj(hidden_states)
|
||||||
|
else:
|
||||||
|
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||||
|
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||||
|
q_nope, q_pe = torch.split(
|
||||||
|
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||||
|
compressed_kv, k_pe = torch.split(
|
||||||
|
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||||
|
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||||
|
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||||
|
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
|
||||||
|
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||||
|
|
||||||
|
# decode
|
||||||
|
if q_len == 1:
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
|
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||||
|
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
|
||||||
|
k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
|
||||||
|
# k_pe [max_pages, page_size, self.qk_rope_head_dim]
|
||||||
|
# compressed_kv [max_pages, page_size, self.kv_lora_rank]
|
||||||
|
|
||||||
|
# q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
|
||||||
|
# q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
|
||||||
|
q_absorb, out_absorb = self.get_absorbed()
|
||||||
|
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
|
||||||
|
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
|
||||||
|
q_nope = q_nope.transpose(1, 2)
|
||||||
|
assert q_nope.is_contiguous()
|
||||||
|
|
||||||
|
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||||
|
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
|
||||||
|
q_nope.squeeze_(1)
|
||||||
|
q_pe.squeeze_(1)
|
||||||
|
|
||||||
|
# flash attn doesn't support head_dim bigger than 256, use flashinfer
|
||||||
|
if self.mla_wrapper is None:
|
||||||
|
self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
|
||||||
|
if self.mla_wrapper.need_plan:
|
||||||
|
self.mla_wrapper.need_plan = False
|
||||||
|
self.mla_wrapper.plan(None,None,None,
|
||||||
|
position_ids.squeeze(1)+1,
|
||||||
|
self.num_heads,
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.qk_rope_head_dim,
|
||||||
|
past_key_value.page_size,
|
||||||
|
self.softmax_scale,
|
||||||
|
q_nope.dtype,
|
||||||
|
compressed_kv.dtype)
|
||||||
|
attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
|
"""
|
||||||
|
k = (
|
||||||
|
torch.cat([compressed_kv, k_pe], dim=-1)
|
||||||
|
.view(-1, 1, 512 + 64)
|
||||||
|
.repeat_interleave(self.num_heads, dim=1)
|
||||||
|
)
|
||||||
|
v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
|
||||||
|
lens = position_ids.item() + 1
|
||||||
|
#print("lens", lens)
|
||||||
|
attn_ref, lse_ref = attention_ref(
|
||||||
|
1,
|
||||||
|
torch.cat([q_nope, q_pe], dim=-1),
|
||||||
|
k[:lens],
|
||||||
|
v[:lens],
|
||||||
|
False,
|
||||||
|
self.softmax_scale
|
||||||
|
)
|
||||||
|
attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
|
||||||
|
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||||
|
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
else:
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
|
k_pe.squeeze(0)
|
||||||
|
compressed_kv.squeeze(0)
|
||||||
|
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||||
|
k_pe.unsqueeze(0)
|
||||||
|
compressed_kv.unsqueeze(0)
|
||||||
|
|
||||||
|
k_pe = k_pe[:, :q_len]
|
||||||
|
compressed_kv = compressed_kv[:, :q_len]
|
||||||
|
kv = (
|
||||||
|
self.kv_b_proj(compressed_kv)
|
||||||
|
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
)
|
||||||
|
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||||
|
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||||
|
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||||
|
|
||||||
|
key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||||
|
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
|
||||||
|
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||||
|
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
|
||||||
|
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||||
|
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states_padded,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.q_head_dim != self.v_head_dim:
|
||||||
|
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(
|
||||||
|
bsz, q_len, self.num_heads * self.v_head_dim
|
||||||
|
).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
def forward_windows(
|
def forward_windows(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
@ -415,7 +569,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
|
if os.name == 'nt':
|
||||||
return self.forward_windows(
|
return self.forward_windows(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
@ -427,16 +581,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_linux(
|
if flashinfer_enabled:
|
||||||
hidden_states,
|
return self.forward_linux_flashinfer(
|
||||||
attention_mask,
|
hidden_states,
|
||||||
position_ids,
|
attention_mask,
|
||||||
past_key_value,
|
position_ids,
|
||||||
output_attentions,
|
past_key_value,
|
||||||
use_cache,
|
output_attentions,
|
||||||
cache_position,
|
use_cache,
|
||||||
**kwargs,
|
cache_position,
|
||||||
)
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.forward_linux_triton(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KLlamaAttention(BaseInjectedModule):
|
class KLlamaAttention(BaseInjectedModule):
|
||||||
|
@ -447,9 +613,10 @@ class KLlamaAttention(BaseInjectedModule):
|
||||||
gguf_loader : GGUFLoader,
|
gguf_loader : GGUFLoader,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
orig_module: nn.Module,
|
orig_module: nn.Module,
|
||||||
device: str = "cuda",
|
prefill_device: str = "cuda",
|
||||||
|
generate_device: str = "cuda",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||||
self.orig_module.__init__(orig_module.config,
|
self.orig_module.__init__(orig_module.config,
|
||||||
orig_module.layer_idx)
|
orig_module.layer_idx)
|
||||||
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
|
|
@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
|
||||||
gguf_loader : GGUFLoader,
|
gguf_loader : GGUFLoader,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
orig_module: nn.Module,
|
orig_module: nn.Module,
|
||||||
device: str = "cuda",
|
prefill_device: str = "cuda",
|
||||||
|
generate_device: str = "cuda",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
nn.Module.__setattr__(self, "orig_module", orig_module)
|
nn.Module.__setattr__(self, "orig_module", orig_module)
|
||||||
object.__setattr__(self, "key", key)
|
object.__setattr__(self, "key", key)
|
||||||
object.__setattr__(self, "gguf_loader", gguf_loader)
|
object.__setattr__(self, "gguf_loader", gguf_loader)
|
||||||
object.__setattr__(self, "config", config)
|
object.__setattr__(self, "config", config)
|
||||||
object.__setattr__(self, "device", device)
|
object.__setattr__(self, "prefill_device", prefill_device)
|
||||||
|
object.__setattr__(self, "generate_device", generate_device)
|
||||||
|
object.__setattr__(self, "device", generate_device)
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
|
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
|
||||||
|
|
|
@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
|
||||||
output_cpu:Tensor = None
|
output_cpu:Tensor = None
|
||||||
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
|
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
|
||||||
#stream_map:dict = {} # Manage cuda stream on different gpu
|
#stream_map:dict = {} # Manage cuda stream on different gpu
|
||||||
|
#gguf_loader:GGUFLoader = None
|
||||||
CPU_INFER = CPUInfer(Config().cpu_infer)
|
CPU_INFER = CPUInfer(Config().cpu_infer)
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||||
|
#if KExpertsCPU.gguf_loader is None:
|
||||||
|
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
|
||||||
|
self.gguf_loader = gguf_loader
|
||||||
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
|
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
|
||||||
self.n_routed_experts = n_routed_experts
|
self.n_routed_experts = n_routed_experts
|
||||||
self.out_device = out_device
|
self.out_device = out_device
|
||||||
|
@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
|
||||||
generate_device: str = "cpu",
|
generate_device: str = "cpu",
|
||||||
generate_op: str | None = "KExpertsCPU",
|
generate_op: str | None = "KExpertsCPU",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||||
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||||
if generate_op is not None:
|
if generate_op is not None:
|
||||||
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
||||||
|
|
240
ktransformers/operators/flashinfer_wrapper.py
Normal file
240
ktransformers/operators/flashinfer_wrapper.py
Normal file
|
@ -0,0 +1,240 @@
|
||||||
|
'''
|
||||||
|
Description : flashinfer MLA wrapper
|
||||||
|
Author : Boxin Zhang
|
||||||
|
Version : 0.2.2
|
||||||
|
'''
|
||||||
|
import torch
|
||||||
|
|
||||||
|
flashinfer_enabled = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flashinfer
|
||||||
|
flashinfer_enabled = True
|
||||||
|
print("found flashinfer")
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("flashinfer not found, use triton for linux")
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
def attention_ref(
|
||||||
|
batch_size,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
causal: bool,
|
||||||
|
sm_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qo_len = q.shape[0] // batch_size
|
||||||
|
kv_len = k.shape[0] // batch_size
|
||||||
|
num_qo_heads = q.shape[1]
|
||||||
|
head_dim_qk = q.shape[2]
|
||||||
|
head_dim_vo = v.shape[2]
|
||||||
|
logits = (
|
||||||
|
torch.einsum(
|
||||||
|
"bmhd,bnhd->bhmn",
|
||||||
|
q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
|
||||||
|
k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
|
||||||
|
)
|
||||||
|
* sm_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
#print("attn weights", logits)
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
mask = (
|
||||||
|
torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
|
||||||
|
>= torch.arange(0, kv_len).unsqueeze(0)
|
||||||
|
).to(q.device)
|
||||||
|
else:
|
||||||
|
mask = torch.ones(qo_len, kv_len).to(q.device)
|
||||||
|
|
||||||
|
logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
|
||||||
|
lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
|
||||||
|
p = torch.softmax(logits, dim=-1)
|
||||||
|
o_ref = (
|
||||||
|
torch.einsum(
|
||||||
|
"bhmn,bnhd->bmhd",
|
||||||
|
p,
|
||||||
|
v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
|
||||||
|
)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * qo_len, num_qo_heads, head_dim_vo)
|
||||||
|
.to(q)
|
||||||
|
)
|
||||||
|
|
||||||
|
return o_ref, lse_ref * math.log2(math.e)
|
||||||
|
|
||||||
|
class MLAWrapper():
|
||||||
|
def __init__(self,
|
||||||
|
max_batch_size,
|
||||||
|
max_pages,
|
||||||
|
use_cuda_graph = True,
|
||||||
|
device = "cuda",
|
||||||
|
):
|
||||||
|
self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_pages = max_pages
|
||||||
|
if use_cuda_graph:
|
||||||
|
if self.max_batch_size == 1:
|
||||||
|
self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)
|
||||||
|
self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)
|
||||||
|
self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
|
||||||
|
else:
|
||||||
|
self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
|
||||||
|
self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
|
||||||
|
self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
|
||||||
|
self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
|
||||||
|
else:
|
||||||
|
self.qo_indptr_buf = None
|
||||||
|
self.kv_indptr_buf = None
|
||||||
|
self.kv_indices_buf = None
|
||||||
|
self.kv_len_arr_buf = None
|
||||||
|
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
||||||
|
self.float_workspace_buffer,
|
||||||
|
use_cuda_graph=False,
|
||||||
|
qo_indptr=self.qo_indptr_buf,
|
||||||
|
kv_indptr=self.kv_indptr_buf,
|
||||||
|
kv_indices=self.kv_indices_buf,
|
||||||
|
kv_len_arr=self.kv_len_arr_buf,
|
||||||
|
)
|
||||||
|
self.need_plan = True
|
||||||
|
|
||||||
|
def plan(self,
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_len_arr,
|
||||||
|
num_heads,
|
||||||
|
head_dim_ckv,
|
||||||
|
head_dim_kpe,
|
||||||
|
page_size,
|
||||||
|
sm_scale,
|
||||||
|
q_data_type,
|
||||||
|
kv_data_type,
|
||||||
|
):
|
||||||
|
if qo_indptr is None:
|
||||||
|
assert self.max_batch_size == 1
|
||||||
|
qo_indptr = self.qo_indptr_buf
|
||||||
|
if kv_indptr is None:
|
||||||
|
assert self.max_batch_size == 1
|
||||||
|
kv_indptr = self.kv_indptr_buf
|
||||||
|
if kv_indices is None:
|
||||||
|
assert self.max_batch_size == 1
|
||||||
|
kv_indices = self.kv_indices_buf
|
||||||
|
|
||||||
|
self.wrapper.plan(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_len_arr,
|
||||||
|
num_heads,
|
||||||
|
head_dim_ckv,
|
||||||
|
head_dim_kpe,
|
||||||
|
page_size,
|
||||||
|
False, # causal is False for decoding
|
||||||
|
sm_scale,
|
||||||
|
q_data_type,
|
||||||
|
kv_data_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
|
||||||
|
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
|
||||||
|
|
||||||
|
class MLAWrapperSingleton():
|
||||||
|
wrappers:dict = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls, device, *args, **kwargs)->MLAWrapper:
|
||||||
|
if device not in cls.wrappers:
|
||||||
|
cls.make_instance(device, *args, **kwargs)
|
||||||
|
return cls.wrappers[device]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_instance(cls, device, *args, **kwargs):
|
||||||
|
cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def plan_all(cls, qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_len_arr,
|
||||||
|
num_heads,
|
||||||
|
head_dim_ckv,
|
||||||
|
head_dim_kpe,
|
||||||
|
page_size,
|
||||||
|
sm_scale,
|
||||||
|
q_data_type,
|
||||||
|
kv_data_type,):
|
||||||
|
for device, wrapper in cls.wrappers.items():
|
||||||
|
kv_len_arr_cur_device = kv_len_arr.to(device)
|
||||||
|
wrapper.plan(qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_len_arr_cur_device,
|
||||||
|
num_heads,
|
||||||
|
head_dim_ckv,
|
||||||
|
head_dim_kpe,
|
||||||
|
page_size,
|
||||||
|
sm_scale,
|
||||||
|
q_data_type,
|
||||||
|
kv_data_type,)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
max_batch_size = 1
|
||||||
|
max_pages = 1
|
||||||
|
page_size = 64
|
||||||
|
num_heads = 128
|
||||||
|
|
||||||
|
q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
|
||||||
|
q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
|
||||||
|
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
|
||||||
|
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
wrapper = MLAWrapperSingleton.get_instance(
|
||||||
|
"cuda",
|
||||||
|
max_batch_size,
|
||||||
|
max_pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
wrapper.plan(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
kv_len_arr,
|
||||||
|
128,
|
||||||
|
512,
|
||||||
|
64,
|
||||||
|
page_size,
|
||||||
|
192 ** (-0.5),
|
||||||
|
torch.bfloat16,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
|
||||||
|
|
||||||
|
k = (
|
||||||
|
torch.cat([ckv, k_pe], dim=-1)
|
||||||
|
.view(-1, 1, 512 + 64)
|
||||||
|
.repeat_interleave(num_heads, dim=1)
|
||||||
|
)
|
||||||
|
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
|
||||||
|
|
||||||
|
print(k[:10].shape)
|
||||||
|
print(v[:10].shape)
|
||||||
|
|
||||||
|
attn_ref, lse_ref = attention_ref(
|
||||||
|
max_batch_size,
|
||||||
|
torch.cat([q_nope, q_pe], dim=-1),
|
||||||
|
k[:10],
|
||||||
|
v[:10],
|
||||||
|
False,
|
||||||
|
192 ** (-0.5)
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
|
||||||
|
print("test past")
|
|
@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
|
||||||
gguf_loader: GGUFLoader,
|
gguf_loader: GGUFLoader,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
orig_module: nn.Module = None,
|
orig_module: nn.Module = None,
|
||||||
generate_device: str = "cuda",
|
|
||||||
prefill_device: str = "cuda",
|
prefill_device: str = "cuda",
|
||||||
|
generate_device: str = "cuda",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||||
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||||
self.generate_device = generate_device
|
self.generate_device = generate_device
|
||||||
self.prefill_device = prefill_device
|
self.prefill_device = prefill_device
|
||||||
|
|
|
@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
|
||||||
prefill_op: str| None = "KLinearTorch",
|
prefill_op: str| None = "KLinearTorch",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||||
KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||||
# build all the linear operators
|
# build all the linear operators
|
||||||
if prefill_op is not None:
|
if prefill_op is not None:
|
||||||
|
|
|
@ -109,6 +109,7 @@ GGML_TYPES = {
|
||||||
"Q5_K": 13,
|
"Q5_K": 13,
|
||||||
"Q6_K": 14,
|
"Q6_K": 14,
|
||||||
"IQ4_XS": 23,
|
"IQ4_XS": 23,
|
||||||
|
"BF16": 30,
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
|
GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
|
||||||
|
@ -116,6 +117,7 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
|
||||||
GGML_BLOCK_SIZES = {
|
GGML_BLOCK_SIZES = {
|
||||||
"F32": 4,
|
"F32": 4,
|
||||||
"F16": 2,
|
"F16": 2,
|
||||||
|
"BF16": 2,
|
||||||
"Q4_0": 2 + 16,
|
"Q4_0": 2 + 16,
|
||||||
"Q5_0": 2 + 4 + 16,
|
"Q5_0": 2 + 4 + 16,
|
||||||
"Q8_0": 2 + 32,
|
"Q8_0": 2 + 32,
|
||||||
|
@ -130,6 +132,7 @@ GGML_BLOCK_SIZES = {
|
||||||
GGML_ELEMENTS_PER_BLOCK = {
|
GGML_ELEMENTS_PER_BLOCK = {
|
||||||
"F32": 1,
|
"F32": 1,
|
||||||
"F16": 1,
|
"F16": 1,
|
||||||
|
"BF16": 1,
|
||||||
"Q4_0": 32,
|
"Q4_0": 32,
|
||||||
"Q5_0": 32,
|
"Q5_0": 32,
|
||||||
"Q8_0": 32,
|
"Q8_0": 32,
|
||||||
|
@ -333,6 +336,8 @@ class GGUFLoader:
|
||||||
else:
|
else:
|
||||||
values = GGML_DEQUANTIZE[ggml_name](data)
|
values = GGML_DEQUANTIZE[ggml_name](data)
|
||||||
values = torch.from_numpy(values)
|
values = torch.from_numpy(values)
|
||||||
|
if ggml_name == "BF16":
|
||||||
|
values = values.view(torch.bfloat16)
|
||||||
values = values.view(shape[::-1])
|
values = values.view(shape[::-1])
|
||||||
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
|
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
|
||||||
n_head = self.gguf_file_meta['llama.attention.head_count']
|
n_head = self.gguf_file_meta['llama.attention.head_count']
|
||||||
|
@ -764,6 +769,7 @@ def dequantize_f16_gpu(data, device):
|
||||||
GGML_DEQUANTIZE = {
|
GGML_DEQUANTIZE = {
|
||||||
"F32": dequantize_f32,
|
"F32": dequantize_f32,
|
||||||
"F16": dequantize_f16,
|
"F16": dequantize_f16,
|
||||||
|
"BF16": dequantize_f16,
|
||||||
"Q4_0": dequantize_q4_0,
|
"Q4_0": dequantize_q4_0,
|
||||||
"Q5_0": dequantize_q5_0,
|
"Q5_0": dequantize_q5_0,
|
||||||
"Q8_0": dequantize_q8_0,
|
"Q8_0": dequantize_q8_0,
|
||||||
|
@ -778,6 +784,7 @@ GGML_DEQUANTIZE = {
|
||||||
GGML_DEQUANTIZE_GPU = {
|
GGML_DEQUANTIZE_GPU = {
|
||||||
"F32": dequantize_f32_gpu,
|
"F32": dequantize_f32_gpu,
|
||||||
"F16": dequantize_f16_gpu,
|
"F16": dequantize_f16_gpu,
|
||||||
|
"BF16": dequantize_f16_gpu,
|
||||||
"Q4_0": dequantize_q4_0_gpu,
|
"Q4_0": dequantize_q4_0_gpu,
|
||||||
"Q5_0": dequantize_q5_0_gpu,
|
"Q5_0": dequantize_q5_0_gpu,
|
||||||
"Q8_0": dequantize_q8_0_gpu,
|
"Q8_0": dequantize_q8_0_gpu,
|
||||||
|
|
|
@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
|
||||||
from ktransformers.models.custom_cache import StaticCache
|
from ktransformers.models.custom_cache import StaticCache
|
||||||
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||||
from ktransformers.util.textstream import TextStreamer
|
from ktransformers.util.textstream import TextStreamer
|
||||||
|
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
||||||
|
|
||||||
warm_uped = False
|
warm_uped = False
|
||||||
|
|
||||||
|
@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
|
||||||
module.load()
|
module.load()
|
||||||
|
|
||||||
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
||||||
mode = 'normal', force_think: bool = False):
|
mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
|
||||||
|
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
|
||||||
import os
|
import os
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
torch._dynamo.config.suppress_errors = True
|
torch._dynamo.config.suppress_errors = True
|
||||||
|
@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
|
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
|
||||||
generated_ids = torch.zeros(
|
generated_ids = torch.zeros(
|
||||||
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
|
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
|
||||||
)
|
)
|
||||||
|
@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
||||||
generated_ids[:, seq_length] = next_token
|
generated_ids[:, seq_length] = next_token
|
||||||
tokens.append(int(next_token))
|
tokens.append(int(next_token))
|
||||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||||
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
|
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
seq_length += 1
|
seq_length += 1
|
||||||
|
|
||||||
|
@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
||||||
warm_uped = True
|
warm_uped = True
|
||||||
cuda_graph_runner = CUDAGraphRunner()
|
cuda_graph_runner = CUDAGraphRunner()
|
||||||
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||||
|
if i > 1 and use_flashinfer_mla:
|
||||||
|
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
|
||||||
|
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||||
|
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
|
||||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
|
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
|
||||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||||
generated_ids[:, cache_position] = next_token.int()
|
generated_ids[:, cache_position] = next_token.int()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue