mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-05 12:09:48 +00:00)

commit f4c198bd42 (parent e9b1216a9a)

    support absorb for prefill long context

8 changed files with 93 additions and 33 deletions
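For orientation, here is a minimal, self-contained sketch of the matrix-absorption idea this commit extends from decode to prefill. It is not the committed code: the sizes are assumptions chosen to mirror the shape comments in the attention hunks below, and `q_absorb` / `out_absorb` stand in for the absorbed projection weights that `get_absorbed()` returns.

```python
import torch

# Assumed, illustrative dimensions (not read from the repository).
bsz, q_len, num_heads = 1, 10, 128
qk_nope_head_dim, kv_lora_rank, v_head_dim = 128, 512, 128

q_nope = torch.randn(bsz, q_len, num_heads, qk_nope_head_dim)
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)

# Absorb the key projection into the query (per-head batched matmul), so attention
# can run directly against the compressed KV cache of width kv_lora_rank.
q_nope = torch.matmul(q_nope.transpose(1, 2), q_absorb).transpose(1, 2)

# Attention over the compressed cache would produce
# attn_output [bsz, q_len, num_heads, kv_lora_rank]; faked here with random data.
attn_output = torch.randn(bsz, q_len, num_heads, kv_lora_rank)

# Absorb the value projection on the way out, then flatten heads for o_proj.
attn_output = torch.matmul(attn_output.transpose(1, 2), out_absorb.mT).transpose(1, 2)
print(attn_output.reshape(bsz, q_len, num_heads * v_head_dim).shape)
```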
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 
@@ -168,7 +168,7 @@ def local_chat(
     assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
     "please change max_seq_len in ~/.ktransformers/config.yaml"
 
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
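The flashinfer kwargs passed to `prefill_and_generate` describe the MLA head geometry. A small illustrative calculation of how they relate (numbers borrowed from the wrapper self-test later in this commit, not from a real model config):

```python
# Illustrative values only; a real run reads these from the model config.
kv_lora_rank = 512       # head_dim_ckv: compressed ("latent") KV dimension
qk_rope_head_dim = 64    # head_dim_kpe: rotary part of the query/key
qk_nope_head_dim = 128   # non-rotary part of the query (assumed here)

q_head_dim = qk_rope_head_dim + qk_nope_head_dim  # 192, as passed above
sm_scale = q_head_dim ** (-0.5)                   # softmax scale used by the kernel
print(q_head_dim, round(sm_scale, 4))
```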
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  prefill_device: str = "cuda",
                  generate_device: str = "cuda",
                  chunck_size: int = 1000,
+                 absorb_for_prefill: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
         self.mla_wrapper = None
+        self.absorb_for_prefill = absorb_for_prefill
 
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
         q_nope = torch.matmul(q_nope, q_absorb) # batched MM
         q_nope = q_nope.transpose(1, 2)
-        assert q_nope.is_contiguous()
+        #assert q_nope.is_contiguous()
 
         # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
         attn_output = attn_output.transpose(1, 2)
         attn_output = torch.matmul(attn_output, out_absorb.mT)
+        attn_output = attn_output.transpose(1, 2)
 
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output)
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
 
         # decode
-        if q_len == 1:
+        if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            q_nope = q_nope.contiguous()
+            #assert q_nope.is_contiguous()
 
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-            q_nope.squeeze_(1)
-            q_pe.squeeze_(1)
+            q_nope.squeeze_(0)
+            q_pe.squeeze_(0)
 
             # flash attn doesn't support head_dim bigger than 256, use flashinfer
             if self.mla_wrapper is None:
                 self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
             if self.mla_wrapper.need_plan:
                 self.mla_wrapper.need_plan = False
+                if q_len == 1:
                     self.mla_wrapper.plan(None,None,None,
                                           position_ids.squeeze(1)+1,
                                           self.num_heads,
                                           self.kv_lora_rank,
                                           self.qk_rope_head_dim,
                                           past_key_value.page_size,
                                           self.softmax_scale,
                                           q_nope.dtype,
                                           compressed_kv.dtype)
+                else:
+                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                    self.mla_wrapper.plan(qo_indptr,None,None,
+                                          kv_len_arr,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
 
             """
@@ -443,6 +461,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
             attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
+            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
 
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if os.name == 'nt' or get_compute_capability()<8:
+            print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
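The new `else:` branch above plans the flashinfer MLA wrapper for a multi-token (absorbed prefill) query instead of the single-token decode layout. A rough illustration of how the two plan inputs differ (shapes and values below are made up, not taken from the repository):

```python
import torch

device = "cpu"  # illustrative; the real code uses the attention module's CUDA device

# Decode step: one new token whose position id is, say, 41.
position_ids_decode = torch.tensor([[41]])           # [bsz=1, q_len=1]
kv_len_decode = position_ids_decode.squeeze(1) + 1   # cached length seen by the kernel

# Absorbed prefill: a chunk of q_len tokens appended to the compressed cache.
q_len = 10
position_ids_prefill = torch.arange(q_len).unsqueeze(0)  # [1, q_len]
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=device)   # query chunk boundaries
kv_len_arr = torch.tensor([position_ids_prefill[0, -1].item() + 1],
                          dtype=torch.int32, device=device)              # total cached length
print(kv_len_decode, qo_indptr, kv_len_arr)
```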
@@ -9,7 +9,7 @@ flashinfer_enabled = False
 
 try:
     import flashinfer
-    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
+    flashinfer_enabled = True
     print("found flashinfer")
 
 except ImportError:
@@ -132,14 +132,14 @@ class MLAWrapper():
             head_dim_ckv,
             head_dim_kpe,
             page_size,
-            False, # causal is False for decoding
+            True, # causal
             sm_scale,
             q_data_type,
             kv_data_type,
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
     wrappers:dict = {}
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
                          sm_scale,
                          q_data_type,
                          kv_data_type,)
+            wrapper.need_plan = False
+
+    @classmethod
+    def need_plan_all(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.need_plan = True
+
+    @classmethod
+    def reset_buffer(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.qo_indptr_buf[1] = 1
 
 
 if __name__ == "__main__":
@@ -187,8 +198,9 @@ if __name__ == "__main__":
     page_size = 64
     num_heads = 128
 
-    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_len = 10
+    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
     k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
 
@@ -199,10 +211,10 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
-        None,
+        qo_indptr,
         None,
         None,
         kv_len_arr,
@@ -216,6 +228,7 @@ if __name__ == "__main__":
     )
 
     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    print(attn_output.shape)
 
     k = (
         torch.cat([ckv, k_pe], dim=-1)
@@ -235,6 +248,7 @@ if __name__ == "__main__":
         False,
         192 ** (-0.5)
     )
+    print(attn_ref.shape)
 
     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
     print("test past")
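The two classmethods added above (`need_plan_all`, `reset_buffer`) are consumed later in this commit: the driver flips every per-device wrapper back into "re-plan" mode before a prefill pass and shrinks the query-pointer buffer back to one token once prefill is done. A minimal stand-in for that registry pattern (illustrative only, not the library class):

```python
import torch

class WrapperRegistrySketch:
    """Toy per-device wrapper registry mimicking the singleton's lifecycle."""
    wrappers: dict = {}

    def __init__(self, device):
        self.device = device
        self.need_plan = True
        # [0, 1] describes a single decode token; a prefill overwrites entry 1 with q_len.
        self.qo_indptr_buf = torch.tensor([0, 1], dtype=torch.int32)

    @classmethod
    def get_instance(cls, device):
        if device not in cls.wrappers:
            cls.wrappers[device] = cls(device)
        return cls.wrappers[device]

    @classmethod
    def need_plan_all(cls):
        # Force every wrapper to re-plan before the next (prefill) forward pass.
        for device, wrapper in cls.wrappers.items():
            wrapper.need_plan = True

    @classmethod
    def reset_buffer(cls):
        # Restore the single-token decode layout after prefill.
        for device, wrapper in cls.wrappers.items():
            wrapper.qo_indptr_buf[1] = 1

w = WrapperRegistrySketch.get_instance("cuda:0")
w.qo_indptr_buf[1] = 10                 # a prefill of 10 tokens was planned
WrapperRegistrySketch.reset_buffer()    # back to decode layout
print(w.qo_indptr_buf)                  # tensor([0, 1], dtype=torch.int32)
```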
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,7 +649,9 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt':
+            if os.name == 'nt' or get_compute_capability()<8:
+                print("for Windows or GPU before ampere, use forward_windows")
+                # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
                 )
@@ -60,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 
 warm_uped = False
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
             input_ids = input_ids.to("cpu")
             inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         torch.cuda.set_device(device)
+        if flashinfer_enabled:
+            MLAWrapperSingleton.need_plan_all()
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.args.max_new_tokens):
 
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                if i > 1 and flashinfer_enabled:
+                if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                                                  num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                                  head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 
 warm_uped = False
 
+def get_compute_capability(device:torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
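With no device argument, the new helper returns the minimum compute-capability major version across all visible GPUs, so a mixed multi-GPU box is treated as its oldest card; the flashinfer MLA path is taken only when that value is at least 8 (Ampere / SM80 or newer). A hedged sketch of that gating logic as a caller might apply it (the function and threshold mirror this commit; the fallback names are illustrative):

```python
import torch

def pick_attention_backend(flashinfer_enabled: bool) -> str:
    if not torch.cuda.is_available():
        return "cpu-fallback"  # illustrative branch, not in the diff
    # Minimum major compute capability across all GPUs, as get_compute_capability() computes it.
    cc_major = min(torch.cuda.get_device_properties(i).major
                   for i in range(torch.cuda.device_count()))
    if flashinfer_enabled and cc_major >= 8:   # Ampere or newer
        return "flashinfer-mla"
    return "forward_windows"                   # Windows or pre-Ampere GPUs fall back to this path

print(pick_attention_backend(flashinfer_enabled=True))
```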
@@ -153,6 +165,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.need_plan_all()
+
     logits = model(
         inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
     )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -175,6 +190,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     else:
         next_token = torch.argmax(next_token_scores, dim=-1)
     first_token_time = time.time() - start_time
+
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.reset_buffer()
 
     prefill_count = seq_length
     prefill_time = first_token_time
@@ -192,15 +210,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
     start_time = time.time()
     for i in range(1, max_new_tokens):
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
             warm_uped = True
             cuda_graph_runner = CUDAGraphRunner()
             cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-        if i > 1 and use_flashinfer_mla:
-            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
-                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()