mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
Merge branch 'main' of https://github.com/KMSorSMS/ktransformers into main
This commit is contained in:
commit
80e0536fb0
20 changed files with 231 additions and 53 deletions
|
@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
|
|||
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||
from ktransformers.util.utils import get_device
|
||||
from typing import Optional
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
class KTransformersThreadContext(TransformersThreadContext):
|
||||
|
@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
|
|||
torch.set_grad_enabled(False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
||||
try:
|
||||
generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||
except:
|
||||
generation_config = GenerationConfig(
|
||||
max_length=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_p=args.temperature,
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
torch.set_default_dtype(config.torch_dtype)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
" belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||
|
||||
self.model.generation_config = generation_config
|
||||
self.device_map = self.model.gguf_loader.tensor_device_map
|
||||
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
|
||||
self.cache = StaticCache(
|
||||
|
@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
dtype=self.model.dtype,
|
||||
)
|
||||
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
|
||||
try:
|
||||
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
|
||||
except:
|
||||
gen_config = GenerationConfig(
|
||||
max_length=128,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True
|
||||
)
|
||||
self.model.generation_config = gen_config
|
||||
|
||||
if self.model.generation_config.pad_token_id is None:
|
||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
@ -128,7 +129,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
if(input_ids_length >= self.args.cache_lens):
|
||||
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
|
||||
|
@ -206,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.reset_buffer()
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
|
@ -215,7 +216,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||
return torch.tensor([self.seq_length - 1], device=device)
|
||||
|
||||
async def inference(self, local_messages, thread_id: str):
|
||||
async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
|
||||
async with self._infer_lock:
|
||||
async for v in super().inference(local_messages, thread_id):
|
||||
async for v in super().inference(local_messages, thread_id, temperature, top_p):
|
||||
yield v
|
||||
|
|
|
@ -202,20 +202,23 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
self.seq_length += 1
|
||||
return self.streamer.put(new_tokens)
|
||||
|
||||
def prepare_logits_wrapper(self, inputs, device):
|
||||
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
if temperature is None or temperature == 0:
|
||||
temperature = self.model.generation_config.temperature
|
||||
if top_p is None:
|
||||
top_p = self.model.generation_config.top_p
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
top_k=self.args.top_k,
|
||||
top_p=self.args.top_p,
|
||||
temperature=self.args.temperature,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||
)
|
||||
self.inputs = inputs
|
||||
self.generation_config = generation_config
|
||||
try: # transformers==4.43
|
||||
self.logits_warper = (
|
||||
self.model._get_logits_warper(generation_config,device=device)
|
||||
self.model._get_logits_warper(generation_config, device=device)
|
||||
)
|
||||
except:
|
||||
self.logits_warper = (
|
||||
|
@ -255,7 +258,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
return self.logits_to_token(logits)
|
||||
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
|
@ -323,7 +326,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
|
@ -365,7 +368,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
self.last_request_id = thread_id
|
||||
return True
|
||||
|
||||
async def inference(self, local_messages, thread_id: str):
|
||||
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
self.streamer.reset()
|
||||
self.profiler.create_and_start_timer("tokenize")
|
||||
if isinstance(local_messages, List):
|
||||
|
@ -392,7 +395,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
print(think, end="",flush=True)
|
||||
yield think
|
||||
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
|
||||
# output think token after prefill done
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue