server mix mla

Xie Weiyu 2025-02-17 20:40:28 +08:00
parent 038bc30888
commit c176e516b5
2 changed files with 23 additions and 15 deletions

@@ -15,7 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device

+warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -73,11 +73,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()

-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        global warm_uped
+        if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
+            warm_uped = True
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
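The new `i` argument is the decode-step index passed down from generate(). Combined with the module-level warm_uped flag, it defers CUDA-graph capture until the model has run one eager warm-up decode step: on the first request of the process (warm_uped still False) the graph is captured at step 2, while on later requests (warm_uped already True) it is captured at step 1. A minimal standalone sketch of that gating condition; should_capture is a hypothetical helper used only for illustration, not part of ktransformers:

warm_uped = False  # module-level flag, spelled as in the diff

def should_capture(use_cuda_graph: bool, i: int) -> bool:
    # Capture the CUDA graph on step 2 of the first request (step 1 serves as
    # the warm-up), and on step 1 of every request after that.
    global warm_uped
    capture = use_cuda_graph and (
        (warm_uped and i == 1) or (not warm_uped and i == 2)
    )
    if capture:
        warm_uped = True
    return capture

# First request captures on i == 2; subsequent requests capture on i == 1.
assert [should_capture(True, i) for i in (1, 2, 3)] == [False, True, False]
assert [should_capture(True, i) for i in (1, 2, 3)] == [True, False, False]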

@@ -18,7 +18,7 @@ import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
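flashinfer_enabled and MLAWrapperSingleton come from ktransformers.operators.flashinfer_wrapper; that module is not part of this diff, but the flag presumably records whether the optional flashinfer dependency imported successfully, so the new code in generate() can fall back to the plain path when it is absent. A hedged sketch of that kind of guard, assuming the usual optional-import pattern rather than the module's actual contents:

try:
    import flashinfer  # optional package providing fused paged-attention/MLA kernels
    flashinfer_enabled = True
except ImportError:
    flashinfer_enabled = False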
@@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.ever_generated_ids.add(last)
         return last

-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
@@ -291,9 +291,15 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def generate(self):
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
+        for i in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                next_token = self.decode_one_tokens()
+                if i > 1 and flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5),
+                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                next_token = self.decode_one_tokens(i)
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
                 assert self.args.batch_size == 1
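When flashinfer is enabled, generate() now re-plans the MLA wrapper before every decode step after the first, so the wrapper's paged-KV metadata tracks the cache as it grows by one token per step; the softmax scale is 1/sqrt(qk_rope_head_dim + qk_nope_head_dim), i.e. it is taken over the full query head formed by the RoPE and no-RoPE halves. A small sketch of those quantities, using DeepSeek-V3-style head dimensions as an assumed example (the values are illustrative, not read from this diff):

qk_nope_head_dim = 128   # assumed example value
qk_rope_head_dim = 64    # assumed example value

# Softmax scale over the concatenated (no-RoPE + RoPE) query head.
sm_scale = (qk_rope_head_dim + qk_nope_head_dim) ** (-0.5)
print(round(sm_scale, 4))  # 0.0722

# Per decode step i > 1 the wrapper is re-planned for the KV length that step
# will see (roughly active_cache_position + 1 in the diff).
prompt_len = 16          # assumed prompt length
for i in range(2, 5):
    kv_len = prompt_len + i
    print(f"step {i}: re-plan for kv_len={kv_len}")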