diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 7cbac7c..c6c9c2e 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -110,15 +110,15 @@ def local_chat(
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
 
     try:
-        model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
-        gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True
-        )
-        model.generation_config = gen_config
+        model.generation_config = GenerationConfig.from_pretrained(model_path)
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
+        gen_config = GenerationConfig(
+            temperature=0.6,
+            top_p=0.95,
+            do_sample=True
+        )
+        model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 35c8093..25b1359 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+
+        assert False
+
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index f8ea3ce..2bec5cc 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
-        
+
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
 
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
 
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 3f5ad8e..03afefa 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -184,8 +184,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
-            None, max_length=max_new_tokens,
-            do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+            None, do_sample=True
+            # change this to modify generate config
+            #top_k=5, top_p=0.85, temperature=0.1
         )
         try: # transformers==4.43
             logits_warper = (