Fix ktransformers-server flashinfer wrapper positional-arg issue;

Fix db path issue
Azure-Tang 2025-04-01 07:30:23 +00:00
parent 203b853c75
commit 31677181c3
3 changed files with 7 additions and 2 deletions


@@ -128,6 +128,9 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
+        if bsz_tensor is None:
+            assert self.max_batch_size == 1
+            bsz_tensor = self.batch_size_tensor_buf
         self.wrapper.plan(
             qo_indptr,
@@ -166,6 +169,7 @@ class MLAWrapperSingleton():
             kv_indptr,
             kv_indices,
             kv_len_arr,
+            bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
@@ -179,6 +183,7 @@ class MLAWrapperSingleton():
             kv_indptr,
             kv_indices,
             kv_len_arr_cur_device,
+            bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
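The first hunk makes bsz_tensor optional in MLAWrapper.plan(), falling back to the wrapper's single-batch buffer the same way kv_indices already does; the two MLAWrapperSingleton hunks then thread the tensor through positionally, right before num_heads. Below is a minimal standalone sketch of that fallback; only the attribute names come from the diff, the signature, shapes, and dtypes are assumptions.

import torch

class MLAWrapperSketch:
    # Stand-in for MLAWrapper: only the pieces needed to show the fallback.
    def __init__(self, max_batch_size: int = 1):
        self.max_batch_size = max_batch_size
        # Single-request defaults; dtypes and shapes are assumptions.
        self.kv_indices_buf = torch.zeros(1, dtype=torch.int32)
        self.batch_size_tensor_buf = torch.ones(1, dtype=torch.int32)

    def plan(self, qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor, num_heads, **kwargs):
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf
        if bsz_tensor is None:
            assert self.max_batch_size == 1
            bsz_tensor = self.batch_size_tensor_buf
        # The real wrapper forwards these to the underlying flashinfer plan;
        # here they are returned so the fallback is observable.
        return kv_indices, bsz_tensor

wrapper = MLAWrapperSketch()
kv_indices, bsz = wrapper.plan(None, None, None, torch.tensor([1]), None, num_heads=128)
print(bsz)  # tensor([1], dtype=torch.int32)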


@@ -341,7 +341,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
-                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                         head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
                         sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
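This call site passes the leading arguments positionally, so once plan_all() gained the new bsz_tensor parameter (inserted before num_heads in the hunks above) the old call simply stopped supplying it. Passing None explicitly keeps the call valid and lets MLAWrapper.plan() fall back to its internal batch-size buffer. A small sketch with an assumed signature and placeholder values (the real method is a classmethod and takes more parameters):

def plan_all(qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor,
             num_heads, head_dim_ckv, head_dim_kpe, page_size,
             sm_scale, q_data_type, kv_data_type):
    # Stand-in for MLAWrapperSingleton.plan_all with the post-commit argument order.
    return bsz_tensor

common = dict(num_heads=128, head_dim_ckv=512, head_dim_kpe=64, page_size=64,
              sm_scale=0.135, q_data_type="bf16", kv_data_type="bf16")  # placeholder values

try:
    # Old call site: bsz_tensor is never bound.
    plan_all(None, None, None, "positions", **common)
except TypeError as err:
    print(err)  # ... missing 1 required positional argument: 'bsz_tensor'

# Fixed call site: pass None explicitly and let the wrapper default it.
plan_all(None, None, None, "positions", None, **common)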


@@ -75,7 +75,7 @@ class Config(metaclass=Singleton):
         # db configs
         self.db_configs: dict = cfg.get("db", {})
         self.db_type = self.db_configs.get("type", "")
-        self.db_host = Config.to_path(self.db_configs.get("host", ""))
+        self.db_host = self.localstore_path
         self.db_port = self.db_configs.get("port", "")
         self.db_name = self.db_configs.get("database", "")
         self.db_pool_size = self.db_configs.get("pool_size")
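The last hunk stops deriving db_host from the db.host config entry and points it at the server's local-store directory instead, so the database file lives alongside the other locally stored data. Purely as an illustration (the consumer of these fields is not part of this diff, and a file-backed SQLite database is an assumption), the resulting settings could be assembled into a connection URL like this:

import os

db_type = "sqlite"                      # from db.type
db_host = "/home/user/.ktransformers"   # now self.localstore_path, not db.host
db_name = "server.db"                   # from db.database

# For a file-backed database the "host" is just a directory on disk.
db_url = f"{db_type}:///{os.path.join(db_host, db_name)}"
print(db_url)  # sqlite:////home/user/.ktransformers/server.db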