From 31677181c3f59c93a08deabeea0846c21cc4b3da Mon Sep 17 00:00:00 2001
From: Azure-Tang
Date: Tue, 1 Apr 2025 07:30:23 +0000
Subject: [PATCH] Fix ktransformers-server flashinfer wrapper position arg
 issue; Fix db position issue

---
 ktransformers/operators/flashinfer_wrapper.py            | 5 +++++
 ktransformers/server/backend/interfaces/transformers.py  | 2 +-
 ktransformers/server/config/config.py                    | 2 +-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index 5700d65..08ec341 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -128,6 +128,9 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
+        if bsz_tensor is None:
+            assert self.max_batch_size == 1
+            bsz_tensor = self.batch_size_tensor_buf
 
         self.wrapper.plan(
             qo_indptr,
@@ -166,6 +169,7 @@ class MLAWrapperSingleton():
                  kv_indptr,
                  kv_indices,
                  kv_len_arr,
+                 bsz_tensor,
                  num_heads,
                  head_dim_ckv,
                  head_dim_kpe,
@@ -179,6 +183,7 @@ class MLAWrapperSingleton():
                          kv_indptr,
                          kv_indices,
                          kv_len_arr_cur_device,
+                         bsz_tensor,
                          num_heads,
                          head_dim_ckv,
                          head_dim_kpe,
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index c7ac80f..1460176 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -341,7 +341,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
-                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
                                                  num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                                  head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
                                                  sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py
index e5cbafc..5ea2443 100644
--- a/ktransformers/server/config/config.py
+++ b/ktransformers/server/config/config.py
@@ -75,7 +75,7 @@ class Config(metaclass=Singleton):
         # db configs
         self.db_configs: dict = cfg.get("db", {})
         self.db_type = self.db_configs.get("type", "")
-        self.db_host = Config.to_path(self.db_configs.get("host", ""))
+        self.db_host = self.localstore_path
         self.db_port = self.db_configs.get("port", "")
         self.db_name = self.db_configs.get("database", "")
         self.db_pool_size = self.db_configs.get("pool_size")
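
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the call path that the flashinfer_wrapper.py and transformers.py hunks change. Only the argument order, the "bsz_tensor is None" fallback, and the buffer names kv_indices_buf / batch_size_tensor_buf are taken from the hunks above; the *Sketch class names, buffer shapes, and the numeric values in the example call are placeholders.

# Sketch only -- not the real ktransformers classes. It mimics how the decode
# loop in transformers.py now passes None for bsz_tensor and how the wrapper
# falls back to its single-request buffer.
import torch


class MLAWrapperSketch:
    def __init__(self, device: str = "cpu"):
        self.max_batch_size = 1
        # Single-request fallback buffers, named after the real ones in the hunks.
        self.kv_indices_buf = torch.zeros(1, dtype=torch.int32, device=device)
        self.batch_size_tensor_buf = torch.ones(1, dtype=torch.int32, device=device)

    def plan(self, qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor,
             num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale,
             q_data_type, kv_data_type):
        # The patch adds the same "default to the single-batch buffer" rule for
        # bsz_tensor that already existed for kv_indices.
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf
        if bsz_tensor is None:
            assert self.max_batch_size == 1
            bsz_tensor = self.batch_size_tensor_buf
        print("plan: kv_len =", kv_len_arr.tolist(),
              "bsz =", bsz_tensor.tolist(), "num_heads =", num_heads)


class MLAWrapperSingletonSketch:
    wrappers: dict = {"cpu": MLAWrapperSketch()}

    @classmethod
    def plan_all(cls, qo_indptr, kv_indptr, kv_indices, kv_len_arr,
                 bsz_tensor,  # the positional slot this patch inserts
                 num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale,
                 q_data_type, kv_data_type):
        for device, wrapper in cls.wrappers.items():
            wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_len_arr.to(device),
                         bsz_tensor, num_heads, head_dim_ckv, head_dim_kpe,
                         page_size, sm_scale, q_data_type, kv_data_type)


if __name__ == "__main__":
    # Mirrors the decode-loop call in transformers.py: positional Nones for the
    # index buffers (and now for bsz_tensor), keyword arguments for the rest.
    kv_len = torch.tensor([17], dtype=torch.int32)
    MLAWrapperSingletonSketch.plan_all(None, None, None, kv_len, None,
                                       num_heads=128, head_dim_ckv=512,
                                       head_dim_kpe=64, page_size=64, sm_scale=0.135,
                                       q_data_type=torch.bfloat16,
                                       kv_data_type=torch.bfloat16)

Judging from the hunks, plan_all previously had no bsz_tensor parameter, so every argument after kv_len_arr reached wrapper.plan one position early; adding the parameter and passing None from the decode loop restores the intended argument alignment while keeping the single-request default.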
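
Reviewer note (not part of the patch): the config.py hunk stops deriving db_host from the YAML "host" field and pins it to localstore_path, so the server's database file lands inside the local store directory. A rough sketch of the resulting location, assuming a SQLAlchemy-style SQLite URL; every value except the db_host assignment is a placeholder:

import os

# Placeholder values standing in for Config attributes; only the fact that
# db_host now comes from localstore_path is taken from the patch itself.
localstore_path = os.path.expanduser("~/.ktransformers/local_store")
db_type = "sqlite"
db_name = "server.db"

db_host = localstore_path  # was derived from the YAML "host" field via Config.to_path
sqlite_url = f"{db_type}:///{os.path.join(db_host, db_name)}"
print(sqlite_url)  # e.g. sqlite:////home/user/.ktransformers/local_store/server.db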