mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-11 07:44:35 +00:00)
clean PR code and disable flashinfer
parent cf4da5fd47, commit a529518346
3 changed files with 13 additions and 23 deletions
@@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
-            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
-            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
-            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
-                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
-            self.q_absorb.weight.data = q_absorb
-            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
-                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
-            self.out_absorb.weight.data = out_absorb
-            #del self.orig_module.kv_b_proj
-        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
-        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        return q_absorb, out_absorb
+            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+
+        return self.q_absorb, self.out_absorb

     def forward_chunck(
         self,
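Note on this hunk: the cleaned-up get_absorbed keeps the absorbed projection matrices as plain per-head views of kv_b_proj.weight instead of copying them into extra nn.Linear modules. A minimal standalone sketch of the shape bookkeeping only (the dimension values are illustrative assumptions, not DeepseekV2's real sizes):

```python
import torch

# Made-up sizes for illustration.
num_heads, kv_lora_rank = 4, 8
qk_nope_head_dim, v_head_dim = 16, 16

# kv_b_proj maps the compressed KV (rank kv_lora_rank) to per-head no-PE K and V.
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# View the flat weight as (heads, nope_k + v, rank) and slice the two blocks.
kv_b_proj = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
q_absorb = kv_b_proj[:, :qk_nope_head_dim, :]    # (num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = kv_b_proj[:, qk_nope_head_dim:, :]  # (num_heads, v_head_dim, kv_lora_rank)

print(q_absorb.shape, out_absorb.shape)
```

Because only views are stored, no weight memory is duplicated and the projection can be folded ("absorbed") directly into the attention matmuls.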
@@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
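Note on the ValueError above: since transformers v4.36 the KV cache is addressed per layer through a layer index, so an attention module built without layer_idx cannot update the cache. A hedged sketch of that requirement using the public Cache API (shapes are arbitrary assumptions; assumes transformers >= 4.36 is installed):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0  # the index this attention layer was constructed with

key_states = torch.randn(1, 4, 7, 64)    # (batch, heads, seq_len, head_dim), made-up sizes
value_states = torch.randn(1, 4, 7, 64)

# Cache.update needs layer_idx to know which per-layer slot to append to;
# without it the module has no way to grow its own slice of the cache.
keys, values = cache.update(key_states, value_states, layer_idx)
print(keys.shape)  # torch.Size([1, 4, 7, 64])
```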
@@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        # if hasattr(self.orig_module, 'kv_b_proj'):
-        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
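For context, get_absorbed feeds the "absorbed" MLA attention path: the no-RoPE query slice is projected into the kv_lora_rank latent space so scores can be computed directly against compressed_kv, and out_absorb maps the latent context back to per-head value dimensions. A standalone sketch of that math only (RoPE term, softmax scaling and masking omitted; all sizes are made-up assumptions, and this is not the module's actual forward):

```python
import torch

bsz, num_heads, q_len = 1, 4, 3
qk_nope_head_dim, v_head_dim, kv_lora_rank, kv_len = 16, 16, 8, 10

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)

# Absorb the K projection into the query: scores are taken in the latent space.
q_latent = torch.matmul(q_nope, q_absorb)                                        # (bsz, heads, q_len, rank)
scores = torch.matmul(q_latent, compressed_kv.transpose(-1, -2).unsqueeze(1))    # (bsz, heads, q_len, kv_len)
attn = torch.softmax(scores, dim=-1)

# Context stays compressed until out_absorb expands it to the value head dim.
ctx_latent = torch.matmul(attn, compressed_kv.unsqueeze(1))                      # (bsz, heads, q_len, rank)
attn_output = torch.matmul(ctx_latent, out_absorb.transpose(-1, -2))             # (bsz, heads, q_len, v_head_dim)
print(attn_output.shape)
```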
@@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
@@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )

@@ -9,7 +9,7 @@ flashinfer_enabled = False

 try:
     import flashinfer
-    flashinfer_enabled = True
+    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
     print("found flashinfer")

 except ImportError:
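The hunk above hard-disables flashinfer even when the import succeeds, so flashinfer_enabled is now a pure availability probe. A hedged sketch of this optional-dependency flag pattern in isolation (the message in the except branch is a placeholder, not the repo's actual text):

```python
flashinfer_enabled = False

try:
    import flashinfer  # noqa: F401  # probe availability only
    # Would normally flip to True; kept False until a compatible flashinfer release is supported.
    flashinfer_enabled = False
    print("found flashinfer")
except ImportError:
    print("flashinfer not available")  # placeholder message
```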
@@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):

         self.profiler.create_and_start_timer("prefill")

-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
-            # output think token after prefill done
-            if Config().user_force_think:
-                think = '<think>\n'
-                print(think, end="",flush=True)
-                yield think
+        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            # output think token after prefill done
+            if t is not None:
+                print(t, end="",flush=True)
+                yield t

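The last hunk changes what the prefill loop streams: instead of emitting the forced '<think>\n' tag from inside the loop, the caller now forwards only non-None tokens yielded by prefill. A standalone sketch of that generator-forwarding pattern (the toy prefill below is an assumption for illustration, not the server's real method):

```python
from typing import Iterator, List, Optional

def prefill(tokens: List[int]) -> Iterator[Optional[str]]:
    # Toy stand-in: yield nothing until the last chunk, then emit one piece of text.
    for i, _ in enumerate(tokens):
        yield None if i < len(tokens) - 1 else "<think>\n"

def inference(tokens: List[int]) -> Iterator[str]:
    for t in prefill(tokens):
        # Forward only real text; placeholder None yields are skipped.
        if t is not None:
            print(t, end="", flush=True)
            yield t

for piece in inference([1, 2, 3]):
    pass
```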