mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00
Merge branch 'main' into temperature_top_p_from_request
This commit is contained in:
commit
26f7b4af11
54 changed files with 1573 additions and 159 deletions
|
@ -15,6 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
|||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||
from ktransformers.util.utils import get_device
|
||||
from typing import Optional
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||
|
||||
warm_uped = False
|
||||
|
||||
|
@ -35,9 +36,9 @@ class KTransformersInterface(TransformersInterface):
|
|||
with torch.device("meta"):
|
||||
self.model = custom_models[config.architectures[0]](config)
|
||||
if default_args.optimize_config_path is None:
|
||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||
optimize_config_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_rule_path = args.optimize_config_path
|
||||
optimize_config_path = args.optimize_config_path
|
||||
|
||||
# print(optimize_config)
|
||||
|
||||
|
@ -47,7 +48,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
|
||||
" belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||
|
||||
self.device_map = self.model.gguf_loader.tensor_device_map
|
||||
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
|
||||
|
@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
torch.cuda.set_device(device)
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.need_plan_all()
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -199,6 +202,9 @@ class KTransformersInterface(TransformersInterface):
|
|||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.reset_buffer()
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue