Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)

Merge branch 'main' into main

Commit ca1dc1e7d1: 94 changed files with 4366 additions and 703 deletions
@@ -14,7 +14,10 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 warm_uped = False
 
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -23,19 +26,29 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.temperature,
+                do_sample=True
+            )
 
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
 
         with torch.device("meta"):
             self.model = custom_models[config.architectures[0]](config)
         if default_args.optimize_config_path is None:
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = args.optimize_config_path
+            optimize_config_path = args.optimize_config_path
 
         # print(optimize_config)
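Note on the added fallback above: top_p is seeded from args.temperature, which reads like a copy-paste slip (args.top_p was presumably intended). A corrected sketch of the fallback, assuming ConfigArgs exposes a top_p field:

    try:
        generation_config = GenerationConfig.from_pretrained(args.model_dir)
    except Exception:
        generation_config = GenerationConfig(
            max_length=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,  # hypothetical fix: use the top_p field, not temperature
            do_sample=True,
        )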
@@ -45,8 +58,8 @@ class KTransformersInterface(TransformersInterface):
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
             " belong to current model):"
         )
-        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -57,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
 
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
@@ -74,10 +78,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
 
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
         if self.args.use_cuda_graph:
+            torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -99,14 +106,15 @@ class KTransformersInterface(TransformersInterface):
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)
 
+        if self.args.use_cuda_graph:
+            warm_uped = True
 
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
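Context for the warm_uped flag above: a CUDA graph can only be captured after at least one eager pass has initialized kernels and allocator state, so the first decode runs eagerly, warm_uped is set, and capture happens on a later call. A minimal, repo-independent sketch of the same capture-then-replay pattern using the plain torch API (CUDAGraphRunner wraps this idea; names here are illustrative):

    import torch

    def make_graphed(fn, example_input):
        # Warm up on a side stream so lazy CUDA initialization is not baked
        # into the captured graph.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            fn(example_input)
        torch.cuda.current_stream().wait_stream(s)

        static_in = example_input.clone()
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_out = fn(static_in)

        def replay(x):
            static_in.copy_(x)  # graphs replay on fixed buffers; copy input in place
            g.replay()
            return static_out

        return replay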
@@ -119,55 +127,98 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
+        if(input_ids_length >= self.args.cache_lens):
+            logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
 
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
 
         if is_new:
-            self.cache.reset()
-            self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            self.ever_generated_ids.clear()
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)
 
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
         mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
 
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size
+
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
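The rewritten prefill no longer resets the cache wholesale when a thread changes: it counts how many leading tokens of the new prompt match what is already in generated_ids, trims the KV cache back to that point via remove_suffix, and re-prefills only the tail, in chunks of chunk_prefill_size. A standalone sketch of the prefix-matching step (toy tensors, not the class state):

    import torch

    def common_prefix_len(prev_ids: torch.Tensor, new_ids: torch.Tensor) -> int:
        # Token-by-token comparison, stopping at the first mismatch. The "- 1"
        # mirrors the loop above: at least one token is always re-prefilled so
        # the model produces fresh logits to sample from.
        limit = min(prev_ids.shape[0], new_ids.shape[0]) - 1
        same = 0
        for i in range(limit):
            if prev_ids[i] != new_ids[i]:
                break
            same += 1
        return same

    prev = torch.tensor([1, 2, 3, 4, 5])  # tokens already covered by the KV cache
    new = torch.tensor([1, 2, 3, 9, 9])   # incoming prompt
    keep = common_prefix_len(prev, new)   # -> 3
    # cache.remove_suffix(keep) keeps positions [0, keep); only new[keep:] is prefilled.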
@@ -176,7 +227,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
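The override above only serializes concurrent requests behind an asyncio lock and forwards the new per-request sampling knobs to the parent class. A hypothetical caller, to show how the extended signature is used (interface stands for a constructed KTransformersInterface; not part of this diff):

    import asyncio

    async def main(interface):
        messages = [{"role": "user", "content": "Hello!"}]
        async for chunk in interface.inference(messages, "thread-1",
                                               temperature=0.6, top_p=0.95):
            print(chunk, end="", flush=True)

    # asyncio.run(main(interface))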
@@ -19,7 +19,7 @@ import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
 
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -171,7 +171,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
@@ -180,7 +180,11 @@ class TransformersInterface(BackendInterfaceBase):
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:, :self.seq_length]
             y = input_ids[:, :self.seq_length]
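Rendering the chat template to a string first (tokenize=False) lets the server strip a trailing <think> token that some templates, e.g. DeepSeek-R1's, append automatically; whether to emit it is then controlled by user_force_think further down instead of by the template. A toy illustration with an assumed template output:

    # hypothetical templated prompt; the real code gets this from apply_chat_template
    input_str = "<|User|>hi<|Assistant|><think>\n"
    if input_str.endswith('<think>\n'):
        input_str = input_str[:-len('<think>\n')]
    # tokenizer.encode(input_str, return_tensors="pt") then yields the input_ids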
@@ -199,14 +203,31 @@ class TransformersInterface(BackendInterfaceBase):
             self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature != 0 else logits
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=self.args.repetition_penalty  # change this to modify generate config
+        )
+        self.inputs = inputs
+        try:  # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config, device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
 
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
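prepare_logits_wrapper leans on transformers' private _prepare_generation_config/_get_logits_warper helpers, hence the version-guarded try/except (the device keyword was only accepted around transformers 4.43, per the comment). A rough public-API equivalent of the resulting sampling chain, sketched with made-up values:

    import torch
    from transformers import (
        LogitsProcessorList,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

    # Same chain the wrapper builds: repetition penalty, then temperature,
    # top-k and top-p warping, then sampling from the softmax.
    warpers = LogitsProcessorList([
        RepetitionPenaltyLogitsProcessor(penalty=1.1),
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.9),
    ])

    input_ids = torch.tensor([[1, 2, 3]])  # hypothetical prompt ids
    logits = torch.randn(1, 32000)         # hypothetical vocab-sized logits
    warped = warpers(input_ids, logits)
    probs = torch.nn.functional.softmax(warped, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)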
@@ -222,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
 
     def decode_one_tokens(self):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -238,38 +257,57 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
-        logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
-            self.cache.reset()
-            self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
 
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
         mask = torch.ones((1, self.seq_length)).to(self.args.device)
+        device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
@@ -281,22 +319,34 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
                 attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
     @torch.no_grad
     def generate(self):
+        self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
+        if(self.args.max_new_tokens <= 0):
+            logger.warning("max_new_tokens is less than 0")
+            yield self.streamer.end()
+            return
+        logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
+
+        for i in range(1, self.args.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+                if flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
+                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                 next_token = self.decode_one_tokens()
                 self.profiler.inc("decode")
-                if next_token == self.tokenizer.eos_token_id:
+                if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
                     assert self.args.batch_size == 1
                     break
                 yield self.append_new_tokens(next_token)
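The decode loop pins PyTorch's scaled-dot-product attention to an explicit backend list and, when flashinfer is available, pre-plans the MLA kernels for the next cache position. A minimal sketch of the sdpa_kernel part in isolation (PyTorch 2.3+ torch.nn.attention API; shapes and device are made up, and FLASH_ATTENTION needs a CUDA device):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.bfloat16)
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)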
@@ -315,7 +365,8 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -325,8 +376,9 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
 
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -334,11 +386,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
 
         if Config().user_force_think:
-            t = "<think>\n"
-            print(t, end="", flush=True)
-            yield t
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            think = '<think>\n'
+            print(think, end="", flush=True)
+            yield think
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+            # output think token after prefill done
+            if t is not None:
+                print(t, end="", flush=True)
+                yield t