Merge pull request #657 from kvcache-ai/feat-absorb-for-long-prefill

Feat absorb for long prefill
Authored by Atream on 2025-02-25 16:53:21 +08:00; committed by GitHub
commit b443c7dfa2
11 changed files with 193 additions and 43 deletions

View file

@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -64,7 +64,6 @@ def local_chat(
     force_think: bool = False,
 ):
     torch.set_grad_enabled(False)
     Config().cpu_infer = cpu_infer
@@ -169,7 +168,7 @@ def local_chat(
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim

View file

@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  prefill_device: str = "cuda",
                  generate_device: str = "cuda",
                  chunck_size: int = 1000,
+                 absorb_for_prefill: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
         self.mla_wrapper = None
+        self.absorb_for_prefill = absorb_for_prefill

     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
         q_nope = torch.matmul(q_nope, q_absorb) # batched MM
         q_nope = q_nope.transpose(1, 2)
-        assert q_nope.is_contiguous()
+        #assert q_nope.is_contiguous()
         # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
         attn_output = attn_output.transpose(1, 2)
         attn_output = torch.matmul(attn_output, out_absorb.mT)
+        attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output)
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
         # decode
-        if q_len == 1:
+        if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            q_nope = q_nope.contiguous()
+            #assert q_nope.is_contiguous()
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-            q_nope.squeeze_(1)
-            q_pe.squeeze_(1)
+            q_nope.squeeze_(0)
+            q_pe.squeeze_(0)
             # flash attn doesn't support head_dim bigger than 256, use flashinfer
             if self.mla_wrapper is None:
                 self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
             if self.mla_wrapper.need_plan:
                 self.mla_wrapper.need_plan = False
-                self.mla_wrapper.plan(None,None,None,
-                                      position_ids.squeeze(1)+1,
-                                      self.num_heads,
-                                      self.kv_lora_rank,
-                                      self.qk_rope_head_dim,
-                                      past_key_value.page_size,
-                                      self.softmax_scale,
-                                      q_nope.dtype,
-                                      compressed_kv.dtype)
+                if q_len == 1:
+                    self.mla_wrapper.plan(None,None,None,
+                                          position_ids.squeeze(1)+1,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
+                else:
+                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                    self.mla_wrapper.plan(qo_indptr,None,None,
+                                          kv_len_arr,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
             """
@@ -441,10 +459,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
+            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if os.name == 'nt' or get_compute_capability()<8:
+            print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
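Aside on the change above: the new absorb_for_prefill switch routes long prefills through the same "absorbed" MLA formulation previously reserved for single-token decode. The kv_b_proj up-projections are folded into the query (q_absorb) and into the output (out_absorb), so attention runs directly against the compressed latent KV cache instead of materializing per-head keys and values. A rough standalone sketch of that shape algebra, with illustrative sizes and without the RoPE term or causal mask; it mirrors the idea, not the repository's exact code:

import torch

bsz, q_len, kv_len, num_heads = 1, 16, 1024, 128
kv_lora_rank, qk_nope_head_dim, v_head_dim = 512, 128, 128

q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)     # W_UK folded into the query
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)         # W_UV folded into the output
q_nope = torch.randn(bsz, q_len, num_heads, qk_nope_head_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)                # latent (compressed) KV cache

q_latent = torch.matmul(q_nope.transpose(1, 2), q_absorb)             # [bsz, num_heads, q_len, kv_lora_rank]
scores = torch.matmul(q_latent, compressed_kv.unsqueeze(1).mT)        # [bsz, num_heads, q_len, kv_len]
attn = torch.softmax(scores * qk_nope_head_dim ** -0.5, dim=-1)       # placeholder scale, no causal mask
ctx = torch.matmul(attn, compressed_kv.unsqueeze(1))                  # [bsz, num_heads, q_len, kv_lora_rank]
out = torch.matmul(ctx, out_absorb.mT)                                # [bsz, num_heads, q_len, v_head_dim]
out = out.transpose(1, 2).reshape(bsz, q_len, num_heads * v_head_dim) # ready for o_proj

The benefit for long prompts is that the cache stays in the kv_lora_rank latent space; the cost is the extra per-head matmuls on the query side, which is why the YAML option warns that prefill may be slower.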

View file

@@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
-        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
+        #print(self.gate_type, self.up_type, self.down_type)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
-        self.up = torch.cat(self.gate, dim=0)
+        self.up = torch.cat(self.up, dim=0)
         self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
         return

     def unload(self):
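The corrected lines stack each expert list with its own tensors; before the fix, self.up and self.down were both rebuilt from self.gate. A toy illustration of the stacking with made-up shapes (not the loader's real tensor layout):

import torch

num_experts, hidden, inter = 4, 8, 16
# Per-expert slices kept with a leading expert dim (shapes are made up for the example).
gate = [torch.randn(1, inter, hidden) for _ in range(num_experts)]
up   = [torch.randn(1, inter, hidden) for _ in range(num_experts)]
down = [torch.randn(1, hidden, inter) for _ in range(num_experts)]

gate = torch.cat(gate, dim=0)   # [num_experts, inter, hidden]
up   = torch.cat(up,   dim=0)   # must be built from the `up` list, not `gate`
down = torch.cat(down, dim=0)   # must be built from the `down` list, not `gate`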

View file

@@ -9,7 +9,7 @@ flashinfer_enabled = False
 try:
     import flashinfer
-    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
+    flashinfer_enabled = True
     print("found flashinfer")
 except ImportError:
@@ -132,14 +132,14 @@ class MLAWrapper():
             head_dim_ckv,
             head_dim_kpe,
             page_size,
-            False, # causal is False for decoding
+            True, # causal
             sm_scale,
             q_data_type,
             kv_data_type,
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

 class MLAWrapperSingleton():
     wrappers:dict = {}
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
                          sm_scale,
                          q_data_type,
                          kv_data_type,)
+            wrapper.need_plan = False
+
+    @classmethod
+    def need_plan_all(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.need_plan = True
+
+    @classmethod
+    def reset_buffer(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.qo_indptr_buf[1] = 1

 if __name__ == "__main__":
@@ -187,8 +198,9 @@ if __name__ == "__main__":
     page_size = 64
     num_heads = 128

-    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_len = 10
+    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
     k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
@@ -199,10 +211,10 @@ if __name__ == "__main__":
         max_pages,
     )
-    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
-        None,
+        qo_indptr,
         None,
         None,
         kv_len_arr,
@@ -216,6 +228,7 @@ if __name__ == "__main__":
     )

     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    print(attn_output.shape)

     k = (
         torch.cat([ckv, k_pe], dim=-1)
@@ -235,6 +248,7 @@ if __name__ == "__main__":
         False,
         192 ** (-0.5)
     )
+    print(attn_ref.shape)
     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
     print("test past")

View file

@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            causal_mask = self._update_causal_mask(
-                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-            )
+            if os.name == 'nt' or get_compute_capability()<8:
+                print("for Windows or GPU before ampere, use forward_windows")
+                # only use mask in forward windows or can't flash attn
+                causal_mask = self._update_causal_mask(
+                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+                )
+            else:
+                causal_mask = None

         # embed positions
         hidden_states = inputs_embeds
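The model now builds an explicit causal mask only on the Windows / pre-Ampere fallback path; on the flash-attention path the kernel applies causality itself, so causal_mask = None avoids allocating a [1, 1, q_len, kv_len] mask for long prompts. A small sketch of that trade-off with a hypothetical helper (not the module's actual method):

import os
import torch

def causal_mask_if_needed(q_len: int, kv_len: int, device) -> torch.Tensor | None:
    """Hypothetical helper: only pay for an explicit mask on the fallback path."""
    pre_ampere = torch.cuda.is_available() and torch.cuda.get_device_properties(0).major < 8
    if not (os.name == "nt" or pre_ampere):
        return None                               # fused kernels handle causality internally
    mask = torch.full((q_len, kv_len), float("-inf"), device=device)
    mask = torch.triu(mask, diagonal=kv_len - q_len + 1)
    return mask[None, None, :, :]                 # [1, 1, q_len, kv_len]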

View file

@@ -60,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:

View file

@ -0,0 +1,86 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
#- match:
# name: "^model\\.layers\\..*\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
# kwargs:
# prefill_device: "cuda"
# prefill_op: "KExpertsTorch"
# generate_device: "cuda"
# generate_op: "KExpertsMarlin"
# recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 warm_uped = False
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
             input_ids = input_ids.to("cpu")
             inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         torch.cuda.set_device(device)
+        if flashinfer_enabled:
+            MLAWrapperSingleton.need_plan_all()
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

View file

@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                if i > 1 and flashinfer_enabled:
+                if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                         head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,

View file

@@ -330,6 +330,8 @@ class GGUFLoader:
         values = GGML_DEQUANTIZE[ggml_name](data)
         values = torch.from_numpy(values.copy())

+        if ggml_name == "BF16":
+            values = values.view(torch.bfloat16)
         values = values.view(shape[-2::-1])

         return values
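The new branch matters because the dequantizer hands back the raw 16-bit payload of a BF16 tensor; calling .view(torch.bfloat16) reinterprets those bytes in place instead of numerically converting them. A small self-contained illustration with hand-built bit patterns (not the loader's real data path):

import torch

# bfloat16 bit patterns for 1.0, 2.0 and 0.5 (the top 16 bits of their float32 encodings).
raw = torch.tensor([0x3F80, 0x4000, 0x3F00], dtype=torch.int16)

values = raw.view(torch.bfloat16)  # reinterpret the same bytes; no arithmetic conversion
print(values)                      # -> 1.0, 2.0, 0.5 as bfloat16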

View file

@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

 warm_uped = False

+def get_compute_capability(device:torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -164,6 +176,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.need_plan_all()
+
     logits = model(
         inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
     )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -186,6 +201,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     else:
         next_token = torch.argmax(next_token_scores, dim=-1)
     first_token_time = time.time() - start_time
+
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.reset_buffer()

     prefill_count = seq_length
     prefill_time = first_token_time
@@ -203,22 +221,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     start_time = time.time()
     for i in range(1, max_new_tokens):
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
             warm_uped = True
             cuda_graph_runner = CUDAGraphRunner()
             cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-        if i > 1 and use_flashinfer_mla:
-            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
-                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
         tokens.append(int(next_token))
         seq_length += 1

-        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
             print(stream.end(), end="", flush=True)
             break
         else:
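With no argument, get_compute_capability() scans every visible GPU and returns the smallest compute-capability major version, so a mixed Ampere plus pre-Ampere machine is treated as pre-Ampere and falls back to the non-FlashInfer path. A minimal usage sketch; the gating condition mirrors the local_chat change earlier in this diff, and the variable name is illustrative:

import torch
from ktransformers.util.utils import get_compute_capability
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

# Ampere (SM 8.x) or newer on every visible GPU is required for the FlashInfer MLA path.
use_flashinfer_mla = (
    torch.cuda.is_available()
    and flashinfer_enabled
    and get_compute_capability() >= 8
)
print("FlashInfer MLA prefill/decode:", "on" if use_flashinfer_mla else "off")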