From c176e516b5d0c5be86111f2942992f37bbcdc53d Mon Sep 17 00:00:00 2001
From: Xie Weiyu
Date: Mon, 17 Feb 2025 20:40:28 +0800
Subject: [PATCH] server mix mla

---
 .../backend/interfaces/ktransformers.py       | 24 ++++++++++---------
 .../server/backend/interfaces/transformers.py | 14 +++++++----
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 4ceb65d..6b8c45a 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -15,7 +15,7 @@
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
-
+warm_uped = False
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -73,11 +73,13 @@ class KTransformersInterface(TransformersInterface):
 
         self._infer_lock = asyncio.Lock()
 
-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        global warm_uped
+        if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
+            warm_uped = True
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -91,14 +93,14 @@ class KTransformersInterface(TransformersInterface):
                     use_cache=True,
                 )
 
-            if hasattr(self, "cuda_graph_runner"):
-                logits = self.cuda_graph_runner(
-                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
-                )
-                self.cache.change_seq_length(1)
-                torch.cuda.synchronize()
-                logits = logits[0, -1, :]
-                return self.logits_to_token(logits)
+        if hasattr(self, "cuda_graph_runner"):
+            logits = self.cuda_graph_runner(
+                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
+            )
+            self.cache.change_seq_length(1)
+            torch.cuda.synchronize()
+            logits = logits[0, -1, :]
+            return self.logits_to_token(logits)
 
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index f18581a..d00fc02 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -18,7 +18,7 @@ import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
-
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.ever_generated_ids.add(last)
         return last
 
-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
@@ -291,9 +291,15 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
    def generate(self):
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
+        for i in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                next_token = self.decode_one_tokens()
+                if i > 1 and flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                next_token = self.decode_one_tokens(i)
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
                 assert self.args.batch_size == 1
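
Note (reviewer sketch, not part of the patch): the ktransformers.py change delays CUDA-graph capture until the model has run one eager decode step. The module-level warm_uped flag makes capture fire at step i == 2 on the first request (step 1 runs eagerly as warm-up) and at step i == 1 on later requests, while the dedented hasattr block replays an existing graph on every subsequent step. Below is a minimal, dependency-free sketch of that gating; GraphRunner, Decoder and decode_eager are hypothetical stand-ins, and only the warm_uped/step-index condition mirrors decode_one_tokens(i).

warm_uped = False  # module-level flag, as in the patch

class GraphRunner:
    """Toy stand-in for CUDAGraphRunner: capture once, then replay."""
    def capture(self, step):
        print(f"capturing graph at decode step {step}")
    def __call__(self, step):
        print(f"replaying graph at decode step {step}")

class Decoder:
    use_cuda_graph = True

    def decode_eager(self, i):
        print(f"eager decode at step {i} (warm-up or fallback)")

    def decode_one_tokens(self, i):
        global warm_uped
        # Capture on step 2 of the first request (step 1 runs eagerly),
        # or on step 1 of any later request once everything is warm.
        if self.use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
            warm_uped = True
            if not hasattr(self, "graph_runner"):
                self.graph_runner = GraphRunner()
                self.graph_runner.capture(i)
        if hasattr(self, "graph_runner"):
            return self.graph_runner(i)   # replay path, taken on every later step
        return self.decode_eager(i)       # eager path, only before capture

d = Decoder()
for i in range(1, 5):  # first request: eager at step 1, capture at step 2, replay afterwards
    d.decode_one_tokens(i)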
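
Note (reviewer sketch, not part of the patch): the transformers.py change re-plans the flashinfer MLA wrapper before each decode step after the first, since the paged KV cache grows by one token per step and the attention metadata has to match the new length; the sm_scale it passes is 1/sqrt(qk_rope_head_dim + qk_nope_head_dim), i.e. the full query head dimension. The sketch below shows only that re-plan-then-decode pattern; Planner is a hypothetical stand-in, while the real code calls MLAWrapperSingleton.plan_all(...) with the model's MLA dimensions as in the hunk above.

class Planner:
    """Toy stand-in for the MLA wrapper: plan() fixes metadata for one KV length."""
    def plan(self, kv_len):
        print(f"planning paged-KV attention for kv_len={kv_len}")
    def run(self, kv_len):
        print(f"decode step attends over kv_len={kv_len} tokens")

planner = Planner()
prompt_len = 8                # tokens already in the cache after prefill (example value)
for i in range(1, 4):         # decode steps, mirroring generate()
    kv_len = prompt_len + i   # the cache grew by one token
    if i > 1:                 # step 1 skips re-planning, matching the i > 1 guard in the patch
        planner.plan(kv_len)
    planner.run(kv_len)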