mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
Merge branch 'fix_precision_MLA' of https://github.com/kvcache-ai/ktransformers into server-prefix-cache
This commit is contained in:
commit
73d072f609
3 changed files with 14 additions and 4 deletions
|
@ -2,6 +2,8 @@
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# clear build dirs
|
# clear build dirs
|
||||||
|
rm -rf build
|
||||||
|
rm -rf *.egg-info
|
||||||
rm -rf ktransformers/ktransformers_ext/build
|
rm -rf ktransformers/ktransformers_ext/build
|
||||||
rm -rf ktransformers/ktransformers_ext/cuda/build
|
rm -rf ktransformers/ktransformers_ext/cuda/build
|
||||||
rm -rf ktransformers/ktransformers_ext/cuda/dist
|
rm -rf ktransformers/ktransformers_ext/cuda/dist
|
||||||
|
|
|
@ -104,7 +104,10 @@ class KTransformersInterface(TransformersInterface):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
logits = logits[0, -1, :]
|
logits = logits[0, -1, :]
|
||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
||||||
|
if self.args.use_cuda_graph:
|
||||||
|
warm_uped = True
|
||||||
|
|
||||||
if self.use_static_cache:
|
if self.use_static_cache:
|
||||||
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
||||||
logits = self.model(
|
logits = self.model(
|
||||||
|
@ -118,7 +121,6 @@ class KTransformersInterface(TransformersInterface):
|
||||||
else:
|
else:
|
||||||
logits = self.model(self.current_ids, return_dict=False)[0]
|
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||||
logits = logits[0, -1, :]
|
logits = logits[0, -1, :]
|
||||||
warm_uped = True
|
|
||||||
|
|
||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ import sys, os
|
||||||
from ..base import ThreadContext, BackendInterfaceBase
|
from ..base import ThreadContext, BackendInterfaceBase
|
||||||
from ktransformers.server.config.log import logger
|
from ktransformers.server.config.log import logger
|
||||||
from ..args import ConfigArgs, default_args
|
from ..args import ConfigArgs, default_args
|
||||||
|
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||||
|
|
||||||
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
||||||
class TextStreamer:
|
class TextStreamer:
|
||||||
|
@ -330,8 +330,14 @@ class TransformersInterface(BackendInterfaceBase):
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def generate(self):
|
def generate(self):
|
||||||
self.profiler.set_counter("decode", 0)
|
self.profiler.set_counter("decode", 0)
|
||||||
for _ in range(1, self.args.max_new_tokens):
|
for i in range(1, self.args.max_new_tokens):
|
||||||
|
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||||
|
if i > 1 and flashinfer_enabled:
|
||||||
|
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
|
||||||
|
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||||
|
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||||
|
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||||
next_token = self.decode_one_tokens()
|
next_token = self.decode_one_tokens()
|
||||||
self.profiler.inc("decode")
|
self.profiler.inc("decode")
|
||||||
if next_token == self.tokenizer.eos_token_id:
|
if next_token == self.tokenizer.eos_token_id:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue