Merge branch 'main' into temperature_top_p_from_request

This commit is contained in:
wang jiahao 2025-02-27 18:08:55 +08:00 committed by GitHub
commit 26f7b4af11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 1573 additions and 159 deletions

View file

@ -15,6 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
warm_uped = False
@ -35,9 +36,9 @@ class KTransformersInterface(TransformersInterface):
with torch.device("meta"):
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
optimize_rule_path = default_optimize_rules[config.architectures[0]]
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = args.optimize_config_path
optimize_config_path = args.optimize_config_path
# print(optimize_config)
@ -47,7 +48,7 @@ class KTransformersInterface(TransformersInterface):
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
@ -199,6 +202,9 @@ class KTransformersInterface(TransformersInterface):
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)

View file

@ -337,14 +337,14 @@ class TransformersInterface(BackendInterfaceBase):
for i in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
if i > 1 and flashinfer_enabled:
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id:
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)