use generation config from json file in official repo

Author: Atream  2025-02-27 11:48:34 +00:00
Parent: 9660b2cc1e
Commit: e645d84794
4 changed files with 57 additions and 21 deletions


@@ -110,15 +110,15 @@ def local_chat(
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
 
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
-        gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True
-        )
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
+        gen_config = GenerationConfig(
+            temperature=0.6,
+            top_p=0.95,
+            do_sample=True
+        )
         model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
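
The hunk above prefers the generation_config.json shipped with the official model repo (loaded via GenerationConfig.from_pretrained) and only falls back to hard-coded sampling defaults when that file cannot be read. A minimal standalone sketch of the same pattern; the helper name and the model_path argument are illustrative, not part of the commit:

from transformers import GenerationConfig

def load_generation_config(model_path: str) -> GenerationConfig:
    try:
        # Prefer the generation_config.json shipped with the official checkpoint.
        return GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        # Same fallback values as the diff: sampling with temperature 0.6, top_p 0.95.
        print(f"generation config can't auto create, make default. Message: {e}")
        return GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)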


@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        assert False
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
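
The new forward_linux_flashinfer keeps the old implementation alive under the forward_linux_flashinfer_chunk name and routes short prompts (q_len at or below the chunk size), or runs without absorbed prefill, back to it; the long absorbed-prefill path is still a placeholder (assert False) in this commit. A reduced, function-level sketch of that dispatch, with illustrative names and a made-up chunk_size default standing in for the module attributes:

import torch

def chunked_path(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for forward_linux_flashinfer_chunk (the pre-existing kernel path).
    return hidden_states

def forward_dispatch(hidden_states: torch.Tensor,
                     chunk_size: int = 256,
                     absorb_for_prefill: bool = False) -> torch.Tensor:
    bsz, q_len, _ = hidden_states.size()
    if q_len <= chunk_size or not absorb_for_prefill:
        # Short sequences, or absorb disabled: reuse the chunked path unchanged.
        return chunked_path(hidden_states)
    # Long absorbed prefill: mirrors the `assert False` placeholder above.
    raise NotImplementedError("absorbed long-prefill path not implemented in this sketch")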


@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
 
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
 
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
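
The new test constants imply a paged KV cache of max_pages * page_size = 128 * 64 = 8192 token slots, comfortably covering the kv_len = 2069 cached tokens used for the single-token (q_len = 1) decode query. A quick check of that sizing:

import math

kv_len, page_size, max_pages = 2069, 64, 128
pages_needed = math.ceil(kv_len / page_size)   # 33 pages for 2069 tokens
assert pages_needed <= max_pages
print(pages_needed, max_pages * page_size)     # 33 8192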
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
 
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
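
With q_len = 1 and causal now set to True, the single query sits at the last position and may attend to all kv_len cached tokens, so an unmasked scaled-dot-product reference gives the same result. A rough torch-only sketch of what a reference check like attention_ref computes under the test's shapes (q and k are 512 + 64 = 576 dims after concatenation, v is 512, and sm_scale is 192 ** -0.5); the helper name and exact signature here are illustrative:

import torch

def naive_attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        sm_scale: float) -> torch.Tensor:
    # q: [q_len, num_heads, 576], k: [kv_len, num_heads, 576], v: [kv_len, num_heads, 512]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    # Output: [q_len, num_heads, 512]
    return torch.einsum("hqk,khe->qhe", probs, v.float())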


@@ -184,8 +184,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
-            None, max_length=max_new_tokens,
-            do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+            None, do_sample=True
+            # change this to modify generate config
+            #top_k=5, top_p=0.85, temperature=0.1
         )
         try: # transformers==4.43
             logits_warper = (
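
The change above stops forcing max_length, top_k, top_p, and temperature at the call site: _prepare_generation_config starts from a copy of model.generation_config (now populated from generation_config.json by the first file in this commit) and only applies the kwargs it is given, so with just do_sample=True the checkpoint's own sampling values take effect. A standalone sketch of that precedence using GenerationConfig directly (the concrete numbers are the fallback values from this commit, used only as an example):

from copy import deepcopy
from transformers import GenerationConfig

checkpoint_cfg = GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)

# Old call site: explicit kwargs clobber the checkpoint defaults.
old = deepcopy(checkpoint_cfg)
old.update(do_sample=True, top_k=5, top_p=0.85, temperature=0.1)

# New call site: only do_sample is forced, so checkpoint sampling values survive.
new = deepcopy(checkpoint_cfg)
new.update(do_sample=True)

print(old.temperature, old.top_p)  # 0.1 0.85
print(new.temperature, new.top_p)  # 0.6 0.95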