mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
Fix ktransformers-server flashinfer wrapper position arg issue;
Fix db position issue
This commit is contained in:
parent
203b853c75
commit
31677181c3
3 changed files with 7 additions and 2 deletions
|
@ -128,6 +128,9 @@ class MLAWrapper():
|
|||
if kv_indices is None:
|
||||
assert self.max_batch_size == 1
|
||||
kv_indices = self.kv_indices_buf
|
||||
if bsz_tensor is None:
|
||||
assert self.max_batch_size == 1
|
||||
bsz_tensor = self.batch_size_tensor_buf
|
||||
|
||||
self.wrapper.plan(
|
||||
qo_indptr,
|
||||
|
@ -166,6 +169,7 @@ class MLAWrapperSingleton():
|
|||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_len_arr,
|
||||
bsz_tensor,
|
||||
num_heads,
|
||||
head_dim_ckv,
|
||||
head_dim_kpe,
|
||||
|
@ -179,6 +183,7 @@ class MLAWrapperSingleton():
|
|||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_len_arr_cur_device,
|
||||
bsz_tensor,
|
||||
num_heads,
|
||||
head_dim_ckv,
|
||||
head_dim_kpe,
|
||||
|
|
|
@ -341,7 +341,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
for i in range(1, self.max_new_tokens):
|
||||
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
|
||||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
|
|
@ -75,7 +75,7 @@ class Config(metaclass=Singleton):
|
|||
# db configs
|
||||
self.db_configs: dict = cfg.get("db", {})
|
||||
self.db_type = self.db_configs.get("type", "")
|
||||
self.db_host = Config.to_path(self.db_configs.get("host", ""))
|
||||
self.db_host = self.localstore_path
|
||||
self.db_port = self.db_configs.get("port", "")
|
||||
self.db_name = self.db_configs.get("database", "")
|
||||
self.db_pool_size = self.db_configs.get("pool_size")
|
||||
|
|
Loading…
Add table
Reference in a new issue