Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 04:30:03 +00:00)
Fix ktransformers-server flashinfer wrapper position arg issue;
Fix db position issue
parent 203b853c75
commit 31677181c3
3 changed files with 7 additions and 2 deletions

@@ -128,6 +128,9 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
+        if bsz_tensor is None:
+            assert self.max_batch_size == 1
+            bsz_tensor = self.batch_size_tensor_buf

         self.wrapper.plan(
             qo_indptr,
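
The new branch mirrors the existing kv_indices fallback: when the single-batch server path calls plan() with bsz_tensor=None, the preallocated batch-size buffer is used instead of forwarding None to flashinfer. A minimal, self-contained sketch of that pattern, assuming buffer shapes and dtypes that the diff does not show:

    import torch

    class TinyWrapper:
        # Illustrative stand-in for MLAWrapper; only the attribute names that
        # appear in the hunk are taken from the diff, everything else is assumed.
        def __init__(self, max_batch_size=1, device="cpu"):
            self.max_batch_size = max_batch_size
            self.kv_indices_buf = torch.zeros(1, dtype=torch.int32, device=device)
            self.batch_size_tensor_buf = torch.ones(1, dtype=torch.int32, device=device)

        def plan(self, kv_indices=None, bsz_tensor=None):
            if kv_indices is None:
                assert self.max_batch_size == 1
                kv_indices = self.kv_indices_buf
            if bsz_tensor is None:  # fallback added by this commit
                assert self.max_batch_size == 1
                bsz_tensor = self.batch_size_tensor_buf
            return kv_indices, bsz_tensor

    print(TinyWrapper().plan())  # both arguments resolve to the preallocated buffers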

@@ -166,6 +169,7 @@ class MLAWrapperSingleton():
             kv_indptr,
             kv_indices,
             kv_len_arr,
+            bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,

@@ -179,6 +183,7 @@ class MLAWrapperSingleton():
             kv_indptr,
             kv_indices,
             kv_len_arr_cur_device,
+            bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
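
Both call sites in MLAWrapperSingleton forward their arguments positionally, so once bsz_tensor sits between the kv-length array and num_heads in the underlying plan() signature, a caller that omits it shifts every later argument by one slot. A toy reproduction of that failure mode (the function below is illustrative, not the real flashinfer API):

    def plan(kv_len_arr, bsz_tensor, num_heads, head_dim_ckv):
        # Stand-in with the same relative argument order as the updated signature.
        return {"num_heads": num_heads, "head_dim_ckv": head_dim_ckv}

    kv_len_arr, bsz_tensor = [8], [1]
    print(plan(kv_len_arr, bsz_tensor, 128, 512))  # updated caller: values land in the right slots
    try:
        plan(kv_len_arr, 128, 512)                 # stale caller: everything shifts and the call fails
    except TypeError as exc:
        print("stale call site fails:", exc)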

@@ -341,7 +341,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
-                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                         head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
                         sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
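
The same decode-loop call, re-wrapped one argument per line with the positional slots labelled; the values are taken verbatim from the changed line, and the slot names follow the plan signature implied by the wrapper hunks above (an annotated view, not a standalone snippet):

    MLAWrapperSingleton.plan_all(
        None, None, None,                                # qo_indptr, kv_indptr, kv_indices
        self.active_cache_position.to(torch.int32) + 1,  # kv_len_arr
        None,                                            # bsz_tensor -> preallocated single-batch buffer
        num_heads=self.model.config.num_attention_heads,
        head_dim_ckv=self.model.config.kv_lora_rank,
        head_dim_kpe=self.model.config.qk_rope_head_dim,
        page_size=self.cache.page_size,
        sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )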

@@ -75,7 +75,7 @@ class Config(metaclass=Singleton):
         # db configs
         self.db_configs: dict = cfg.get("db", {})
         self.db_type = self.db_configs.get("type", "")
-        self.db_host = Config.to_path(self.db_configs.get("host", ""))
+        self.db_host = self.localstore_path
         self.db_port = self.db_configs.get("port", "")
         self.db_name = self.db_configs.get("database", "")
         self.db_pool_size = self.db_configs.get("pool_size")
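
With this change db_host no longer derives from the YAML db.host entry but reuses the already-resolved localstore_path, so the database file is created under the local store directory; that appears to be the "db position issue" named in the commit message. A hypothetical illustration, assuming the server joins db_host and the configured database name into an SQLite URL (that consumer is not part of this diff, and both values below are made-up examples):

    import os

    localstore_path = os.path.expanduser("~/.ktransformers/local_store")  # assumed value of self.localstore_path
    db_name = "server.db"                                                 # assumed db.database entry
    sqlite_url = "sqlite:///" + os.path.join(localstore_path, db_name)
    print(sqlite_url)  # where the server database would now live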