Merge pull request #657 from kvcache-ai/feat-absorb-for-long-prefill
Feat absorb for long prefill

Commit: b443c7dfa2
11 changed files with 193 additions and 43 deletions
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -64,7 +64,6 @@ def local_chat(
     force_think: bool = False,
 ):

     torch.set_grad_enabled(False)

     Config().cpu_infer = cpu_infer
@@ -169,7 +168,7 @@ def local_chat(
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
             "please change max_seq_len in ~/.ktransformers/config.yaml"

-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
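For reference, the new guard can be read as the standalone sketch below. The architecture string and flashinfer_enabled flag are stand-ins, not values read from a real config; the capability helper mirrors get_compute_capability() added to util/utils.py in this same commit. Note that the parenthesised `or "DeepseekV3ForCausalLM"` in the committed condition is always truthy; a membership test is presumably what the guard intends.

    import platform
    import torch

    # Minimal sketch of the flashinfer-MLA dispatch above; names below are illustrative stubs.
    flashinfer_enabled = True
    architecture = "DeepseekV3ForCausalLM"

    def min_compute_capability_major() -> int:
        # mirrors get_compute_capability() from util/utils.py: the weakest visible GPU decides
        if not torch.cuda.is_available():
            return 0
        return min(torch.cuda.get_device_properties(i).major
                   for i in range(torch.cuda.device_count()))

    use_flashinfer_mla = (
        platform.system() != "Windows"
        and architecture in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM")
        and flashinfer_enabled
        and min_compute_capability_major() >= 8   # Ampere (SM80) or newer
    )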
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  prefill_device: str = "cuda",
                  generate_device: str = "cuda",
                  chunck_size: int = 1000,
+                 absorb_for_prefill: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
         self.mla_wrapper = None
+        self.absorb_for_prefill = absorb_for_prefill

     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
         q_nope = torch.matmul(q_nope, q_absorb) # batched MM
         q_nope = q_nope.transpose(1, 2)
-        assert q_nope.is_contiguous()
+        #assert q_nope.is_contiguous()

         # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
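The surrounding code relies on the MLA weight-absorption trick: instead of expanding the compressed KV cache back to full per-head keys, the no-RoPE part of the query is projected once into the kv_lora_rank latent space, so attention can run directly against the 512-dim cache. A shape-only sketch, with illustrative sizes and q_absorb standing in for the tensor returned by get_absorbed():

    import torch

    bsz, q_len, num_heads = 1, 8, 128
    qk_nope_head_dim, kv_lora_rank = 128, 512

    q_nope = torch.randn(bsz, q_len, num_heads, qk_nope_head_dim)
    q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # stand-in for get_absorbed()

    q_nope = q_nope.transpose(1, 2)          # [bsz, num_heads, q_len, qk_nope_head_dim]
    q_nope = torch.matmul(q_nope, q_absorb)  # [bsz, num_heads, q_len, kv_lora_rank], batched MM
    q_nope = q_nope.transpose(1, 2)          # [bsz, q_len, num_heads, kv_lora_rank], a strided view
    # with q_len > 1 this transposed result is no longer contiguous, which is why the
    # hard assert is dropped here and an explicit .contiguous() is added on the flashinfer path
    assert not q_nope.is_contiguous()
    print(q_nope.shape)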
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
         attn_output = attn_output.transpose(1, 2)
         attn_output = torch.matmul(attn_output, out_absorb.mT)
         attn_output = attn_output.transpose(1, 2)

         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output)
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

         # decode
-        if q_len == 1:
+        if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            q_nope = q_nope.contiguous()
+            #assert q_nope.is_contiguous()

             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-            q_nope.squeeze_(1)
-            q_pe.squeeze_(1)
+            q_nope.squeeze_(0)
+            q_pe.squeeze_(0)

             # flash attn doesn't support head_dim bigger than 256, use flashinfer
             if self.mla_wrapper is None:
                 self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
             if self.mla_wrapper.need_plan:
                 self.mla_wrapper.need_plan = False
-                self.mla_wrapper.plan(None,None,None,
-                                      position_ids.squeeze(1)+1,
-                                      self.num_heads,
-                                      self.kv_lora_rank,
-                                      self.qk_rope_head_dim,
-                                      past_key_value.page_size,
-                                      self.softmax_scale,
-                                      q_nope.dtype,
-                                      compressed_kv.dtype)
+                if q_len == 1:
+                    self.mla_wrapper.plan(None,None,None,
+                                          position_ids.squeeze(1)+1,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
+                else:
+                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                    self.mla_wrapper.plan(qo_indptr,None,None,
+                                          kv_len_arr,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)

             """
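The new branch only changes how the wrapper is planned: a single decode token keeps the wrapper defaults (qo_indptr of [0, 1]), while an absorbed prefill chunk plans one ragged query of length q_len against the tokens cached so far. A minimal sketch of that argument selection, with tensor names taken from the diff (the wrapper itself is not reimplemented here):

    import torch

    def plan_args(q_len: int, position_ids: torch.Tensor, device: str = "cpu"):
        # position_ids: [bsz, q_len]; the last position tells how many tokens are cached
        if q_len == 1:
            # decode: wrapper keeps its default qo_indptr; kv length is the current position + 1
            qo_indptr = None
            kv_len_arr = (position_ids.squeeze(1) + 1).to(torch.int32)
        else:
            # absorbed long prefill: one ragged query of q_len tokens for the single sequence
            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=device)
            kv_len_arr = torch.tensor([position_ids[0, -1].item() + 1], dtype=torch.int32, device=device)
        return qo_indptr, kv_len_arr

    print(plan_args(1, torch.tensor([[41]])))
    print(plan_args(8, torch.arange(8).unsqueeze(0)))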
@@ -441,10 +459,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
+            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]

-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)

             return attn_output, None, past_key_value
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if os.name == 'nt' or get_compute_capability()<8:
+            print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
-        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
+        #print(self.gate_type, self.up_type, self.down_type)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)

-        self.up = torch.cat(self.gate, dim=0)
+        self.up = torch.cat(self.up, dim=0)
         self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
         return

     def unload(self):
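The two changed lines fix a load-time bug: the up and down projections were previously built by concatenating the gate slices, so KExpertsTorch served the gate weights for two of its three projections. A reduced sketch of the intended stacking, with made-up dimensions for illustration:

    import torch

    n_experts, hidden, inter = 4, 16, 32
    gate = [torch.randn(1, inter, hidden) for _ in range(n_experts)]
    up = [torch.randn(1, inter, hidden) for _ in range(n_experts)]
    down = [torch.randn(1, hidden, inter) for _ in range(n_experts)]

    gate = torch.cat(gate, dim=0)  # [n_experts, inter, hidden]
    up = torch.cat(up, dim=0)      # previously concatenated the gate list by mistake
    down = torch.cat(down, dim=0)  # previously concatenated the gate list by mistake
    print(gate.shape, up.shape, down.shape)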
@@ -9,7 +9,7 @@ flashinfer_enabled = False

 try:
     import flashinfer
-    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
+    flashinfer_enabled = True
     print("found flashinfer")

 except ImportError:
@@ -132,14 +132,14 @@ class MLAWrapper():
             head_dim_ckv,
             head_dim_kpe,
             page_size,
-            False, # causal is False for decoding
+            True, # causal
             sm_scale,
             q_data_type,
             kv_data_type,
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

 class MLAWrapperSingleton():
     wrappers:dict = {}
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
                 sm_scale,
                 q_data_type,
                 kv_data_type,)
             wrapper.need_plan = False

+    @classmethod
+    def need_plan_all(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.need_plan = True
+
+    @classmethod
+    def reset_buffer(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.qo_indptr_buf[1] = 1
+

 if __name__ == "__main__":
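A note on the buffer that reset_buffer touches: qo_indptr is the ragged-query index pointer for a batch of one sequence, so entry 1 holds the number of query tokens in the current plan. Prefill writes q_len there; resetting it to 1 restores the single-token shape that the decode path (and any captured CUDA graph) expects. A tiny illustration, assuming that buffer layout:

    import torch

    qo_indptr_buf = torch.tensor([0, 64], dtype=torch.int32)  # after planning a 64-token prefill chunk
    qo_indptr_buf[1] = 1                                       # back to one query token per decode step
    print(qo_indptr_buf)                                       # tensor([0, 1], dtype=torch.int32)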
@@ -187,8 +198,9 @@ if __name__ == "__main__":
     page_size = 64
     num_heads = 128

-    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_len = 10
+    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
     k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
@@ -199,10 +211,10 @@ if __name__ == "__main__":
         max_pages,
     )

-    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
-        None,
+        qo_indptr,
         None,
         None,
         kv_len_arr,
@@ -216,6 +228,7 @@ if __name__ == "__main__":
     )

     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    print(attn_output.shape)

     k = (
         torch.cat([ckv, k_pe], dim=-1)
@@ -235,6 +248,7 @@ if __name__ == "__main__":
         False,
         192 ** (-0.5)
     )
+    print(attn_ref.shape)

     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
     print("test past")
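With q_len raised to 10, the __main__ check now exercises a prefill-shaped query against the paged cache and compares it to a plain attention reference built from [ckv | k_pe] keys and ckv values, scaled by 192 ** (-0.5), the original un-absorbed head dim. A scaled-down, CPU-only re-creation of that reference, under the assumption that the last q_len queries attend causally to kv_len cached tokens:

    import torch

    q_len, kv_len, num_heads = 4, 16, 8
    d_ckv, d_kpe = 32, 8

    q_nope = torch.randn(q_len, num_heads, d_ckv)
    q_pe = torch.randn(q_len, num_heads, d_kpe)
    ckv = torch.randn(kv_len, d_ckv)
    k_pe = torch.randn(kv_len, d_kpe)

    q = torch.cat([q_nope, q_pe], dim=-1)                    # [q_len, heads, d_ckv + d_kpe]
    k = torch.cat([ckv, k_pe], dim=-1)                       # [kv_len, d_ckv + d_kpe]
    # the repo test scales by 192 ** -0.5 (the original q head dim); this sketch scales by its own key dim
    scores = torch.einsum("qhd,kd->hqk", q, k) * (d_ckv + d_kpe) ** -0.5
    # query i (i-th of the last q_len positions) may see the first kv_len - q_len + i + 1 tokens
    allowed = torch.arange(kv_len) <= (torch.arange(q_len) + kv_len - q_len).unsqueeze(-1)
    scores = scores.masked_fill(~allowed, float("-inf"))
    attn_ref = torch.einsum("hqk,kd->qhd", scores.softmax(dim=-1), ckv)
    print(attn_ref.shape)                                    # [q_len, heads, d_ckv]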
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            causal_mask = self._update_causal_mask(
-                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-            )
+            if os.name == 'nt' or get_compute_capability()<8:
+                print("for Windows or GPU before ampere, use forward_windows")
+                # only use mask in forward windows or can't flash attn
+                causal_mask = self._update_causal_mask(
+                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+                )
+            else:
+                causal_mask = None

         # embed positions
         hidden_states = inputs_embeds
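The rewritten mask logic can be summarized as: only build an explicit causal mask when the model has to fall back to the Windows/math path; on Linux with an Ampere-or-newer GPU the flash/flashinfer kernels apply causality themselves, so passing None skips materializing a [bsz, 1, q_len, kv_len] tensor for long prefills. A condensed sketch of that decision (the flag and capability names mirror the ones used above):

    import os

    def needs_explicit_causal_mask(per_layer_prefill_flag: bool, compute_capability_major: int) -> bool:
        if per_layer_prefill_flag:
            return False          # layer-wise prefill path never builds the mask
        # Windows or a pre-Ampere GPU falls back to forward_windows, which needs the mask
        return os.name == "nt" or compute_capability_major < 8

    print(needs_explicit_causal_mask(False, 9))  # False: flash/flashinfer handles causality
    print(needs_explicit_causal_mask(False, 7))  # True: build the mask for the fallback path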
@@ -60,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (new file, 86 lines)
@@ -0,0 +1,86 @@
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False  # don't recursively inject submodules of this module
+# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
+#- match:
+#    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+#  replace:
+#    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
+#    kwargs:
+#      prefill_device: "cuda"
+#      prefill_op: "KExpertsTorch"
+#      generate_device: "cuda"
+#      generate_op: "KExpertsMarlin"
+#  recursive: False  # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton


 warm_uped = False
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
         input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         torch.cuda.set_device(device)
+        if flashinfer_enabled:
+            MLAWrapperSingleton.need_plan_all()
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.args.max_new_tokens):

             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                if i > 1 and flashinfer_enabled:
+                if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                         head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
@@ -330,6 +330,8 @@ class GGUFLoader:
         values = GGML_DEQUANTIZE[ggml_name](data)
         values = torch.from_numpy(values.copy())

+        if ggml_name == "BF16":
+            values = values.view(torch.bfloat16)
         values = values.view(shape[-2::-1])

         return values
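Background for the two added lines: numpy has no bfloat16 dtype, so a BF16 GGUF tensor presumably arrives from the dequantize table as raw 16-bit data; viewing it as torch.bfloat16 reinterprets the bits without any numeric conversion. A small sketch, with bit patterns chosen by hand:

    import numpy as np
    import torch

    raw = np.array([0x3F80, 0x4000], dtype=np.int16)  # bf16 bit patterns for 1.0 and 2.0
    values = torch.from_numpy(raw.copy())             # int16 carrier, same 16-bit payload
    values = values.view(torch.bfloat16)              # reinterpret, as in the diff
    print(values)                                     # tensor([1., 2.], dtype=torch.bfloat16)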
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

 warm_uped = False

+def get_compute_capability(device:torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -164,6 +176,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.need_plan_all()
+
     logits = model(
         inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
     )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -187,6 +202,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     next_token = torch.argmax(next_token_scores, dim=-1)
     first_token_time = time.time() - start_time

+    if use_flashinfer_mla:
+        MLAWrapperSingleton.reset_buffer()
+
     prefill_count = seq_length
     prefill_time = first_token_time
     if force_think:
@@ -203,22 +221,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     start_time = time.time()
     for i in range(1, max_new_tokens):
-        if use_flashinfer_mla:
-            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
-                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
             warm_uped = True
             cuda_graph_runner = CUDAGraphRunner()
             cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        if i > 1 and use_flashinfer_mla:
+            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
         tokens.append(int(next_token))
         seq_length += 1

-        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
             print(stream.end(), end="", flush=True)
             break
         else:
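Taken together, the utils.py hunks give prefill_and_generate the following flashinfer-MLA bookkeeping. A sketch of just the call ordering; model.prefill, model.decode, and maybe_capture_cuda_graph are stand-ins rather than real ktransformers APIs, and only the sequencing of the wrapper calls is taken from the diff:

    def generate(model, wrapper, inputs, max_new_tokens, use_flashinfer_mla, use_cuda_graph):
        if use_flashinfer_mla:
            wrapper.need_plan_all()            # before prefill: every device re-plans inside the forward
        next_token = model.prefill(inputs)
        if use_flashinfer_mla:
            wrapper.reset_buffer()             # after prefill: qo_indptr back to [0, 1] for decode
        for i in range(1, max_new_tokens):
            if use_cuda_graph:
                model.maybe_capture_cuda_graph(i)   # graph capture on the warm-up step comes first
            if i > 1 and use_flashinfer_mla:
                wrapper.plan_all()             # then re-plan; from step 2 on, the kv length has grown
            next_token = model.decode(next_token)
        return next_token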