Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10)

Merge pull request #719 from kvcache-ai/fix-use-generation-json
use generation config from json file in official repo

Commit 85e2cc7bf4, 4 changed files with 57 additions and 21 deletions
```diff
@@ -110,15 +110,15 @@ def local_chat(
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
 
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
-        gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
+        gen_config = GenerationConfig(
+            temperature=0.6,
+            top_p=0.95,
             do_sample=True
         )
         model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
```
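The change above makes the loader prefer the sampling defaults published with the model (the generation_config.json in the official repo) and only fall back to hand-written defaults when that file cannot be read. A minimal standalone sketch of the same pattern, assuming only that `transformers` is installed and `model_path` points at a model directory or Hub repo:

```python
from transformers import GenerationConfig

def load_generation_config(model_path: str) -> GenerationConfig:
    """Prefer the repo's generation_config.json; otherwise build a default."""
    try:
        # Reads generation_config.json from a local directory or the Hub.
        return GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        # Fallback values mirroring the ones introduced in this commit.
        return GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)
```

The caller would then assign the result to `model.generation_config`, exactly as the hunk does.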
```diff
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+
+        assert False
+
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
```
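Together these two hunks rename the existing FlashInfer path to `forward_linux_flashinfer_chunk` and introduce `forward_linux_flashinfer` as a thin dispatcher: prefills that fit in one chunk (`q_len <= self.chunck_size`), or runs with absorbed prefill disabled, are delegated to the chunked path, while the remaining branch is left guarded by `assert False` in this commit. A schematic sketch of that dispatch-by-length idea, using illustrative names rather than the project's real attention module:

```python
from typing import Callable

class PrefillDispatcher:
    """Illustrative only: route short prefills to an existing chunked kernel path."""

    def __init__(self, chunk_size: int, absorb_for_prefill: bool = False):
        self.chunk_size = chunk_size
        self.absorb_for_prefill = absorb_for_prefill

    def forward(self, q_len: int, chunk_path: Callable[[], object], absorbed_path: Callable[[], object]):
        # Same guard as in the diff: short prefills, or absorption disabled,
        # reuse the proven chunked implementation.
        if q_len <= self.chunk_size or not self.absorb_for_prefill:
            return chunk_path()
        # Long absorbed prefills would take a separate path (not exercised in this commit).
        return absorbed_path()
```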
```diff
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
 
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
```
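For reference, the `kv_indices` default above only works because `max_batch_size == 1`: the whole paged KV cache belongs to a single request, so the index tensors take the simplest possible CSR-style layout. A small sketch of how such tensors are typically built for one request (the `page_size` and `kv_len` values are just the ones used in the test below, not read from the wrapper):

```python
import math
import torch

page_size, kv_len = 64, 2069            # assumed example values
num_pages = math.ceil(kv_len / page_size)

kv_indices = torch.arange(num_pages, dtype=torch.int32)        # page ids owned by request 0
kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32)    # per-request offsets into kv_indices
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32)         # exact KV length of request 0

print(num_pages, kv_indptr.tolist(), kv_len_arr.tolist())      # 33 [0, 33] [2069]
```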
```diff
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
 
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
 
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
```
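The rewritten test now decodes a single query token (`q_len = 1`) against a much longer paged cache (`kv_len = 2069` spread over up to `max_pages = 128` pages of 64), and the flipped `False → True` argument asks `attention_ref` for a causal reference. A comparable sanity check can be written directly with PyTorch's `scaled_dot_product_attention`; this is an independent sketch rather than the repo's `attention_ref` helper, it assumes a CUDA device, and the `scale=` keyword needs a recent PyTorch (2.1+). For a decode step the causal mask must be aligned to the end of the KV sequence, done here with an explicit boolean mask:

```python
import torch
import torch.nn.functional as F

num_heads, q_len, kv_len = 128, 1, 2069
q = torch.randn(q_len, num_heads, 576, dtype=torch.bfloat16, device="cuda")   # [q_nope | q_pe]
k = torch.randn(kv_len, num_heads, 576, dtype=torch.bfloat16, device="cuda")  # [ckv | k_pe] per head
v = torch.randn(kv_len, num_heads, 512, dtype=torch.bfloat16, device="cuda")  # compressed KV as values

# Causal mask aligned to the end of the cache: query i sees keys 0 .. (kv_len - q_len + i).
idx_q = torch.arange(q_len, device="cuda").unsqueeze(1)
idx_k = torch.arange(kv_len, device="cuda").unsqueeze(0)
mask = idx_k <= idx_q + (kv_len - q_len)

out = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),   # (1, heads, q_len, 576)
    k.transpose(0, 1).unsqueeze(0),   # (1, heads, kv_len, 576)
    v.transpose(0, 1).unsqueeze(0),   # (1, heads, kv_len, 512)
    attn_mask=mask,                   # broadcasts over batch and heads
    scale=192 ** (-0.5),              # same scale as the test above
)
print(out.shape)                      # torch.Size([1, 128, 1, 512])
```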
```diff
@@ -183,8 +183,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
-            None, max_length=max_new_tokens,
-            do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+            None, do_sample=True
+            # change this to modify generate config
+            #top_k=5, top_p=0.85, temperature=0.1
         )
         try: # transformers==4.43
             logits_warper = (
```
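With the hard-coded `top_k=5, top_p=0.85, temperature=0.1` overrides commented out, `_prepare_generation_config` only receives `do_sample=True`, so the effective sampling parameters now come from `model.generation_config`, i.e. from the `generation_config.json` shipped with the official model repo (or the fallback built in `local_chat`). As a reminder of what those knobs control at decode time, here is a generic temperature plus top-p (nucleus) sampling step over a single logits vector; an illustrative sketch, not the transformers logits-processor pipeline:

```python
import torch

def sample_token(logits: torch.Tensor, temperature: float = 0.6, top_p: float = 0.95) -> int:
    """Sample one token id from a 1-D logits tensor with temperature and nucleus filtering."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cum_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep = cum_before < top_p          # smallest prefix whose mass reaches top_p
    keep[0] = True                     # always keep the most likely token
    filtered = sorted_probs * keep
    filtered = filtered / filtered.sum()
    return int(sorted_ids[torch.multinomial(filtered, 1)])

# Example: token_id = sample_token(torch.randn(102400), temperature=0.6, top_p=0.95)
```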