kvcache-ai-ktransformers/ktransformers/server/backend/interfaces/transformers.py
2025-07-22 10:58:25 +00:00

599 lines
25 KiB
Python

from typing import Any, List, Optional, Set
import re
import json
import uuid
from transformers import (
LlamaTokenizer,
AutoTokenizer,
AutoConfig,
LlamaForCausalLM,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
MinPLogitsWarper,
TypicalLogitsWarper,
EpsilonLogitsWarper,
EtaLogitsWarper,
)
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
from torch.nn.attention import SDPBackend
import torch
import sys, os
from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
# Detect an Ascend NPU runtime; every NPU-specific code path below is gated on `use_torch_npu`.
try:
    import torch_npu
    from ktransformers.util import utils
    use_torch_npu = torch_npu.npu.is_available()
except Exception:
    # torch_npu or the NPU helpers are unavailable (or failed to initialise), so fall back to the
    # non-NPU code paths. `Exception` rather than a bare `except` lets KeyboardInterrupt and
    # SystemExit propagate.
    use_torch_npu = False
import torch.distributed as dist
# This TextStreamer is adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py;
# unlike the original, put() and end() return the text to the caller instead of printing it.
class TextStreamer:
    """Turn a stream of single token ids into printable text fragments.

    Adapted from ``transformers.generation.streamers.TextStreamer``. Instead of
    printing, ``put`` and ``end`` return the newly printable text to the caller.

    Every ``put`` re-decodes the whole cached token list and returns only the
    characters not handed out yet, holding back a trailing partial word:

    * decoded text ending in a newline is released in full and the cache is cleared;
    * decoded text ending in a CJK ideograph is released in full;
    * otherwise text is released up to and including the last space.
    """

    # Inclusive code-point ranges treated as "Chinese characters" (CJK ideograph blocks).
    # Hangul, Hiragana and Katakana are deliberately absent: those scripts separate
    # words with spaces, so they take the generic last-space path like other languages.
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs
        # Streaming state: ids decoded so far, how many characters were already
        # returned, and whether the next put() is the prompt that should be skipped.
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def reset(self):
        """Forget the cached ids and the count of characters already returned."""
        self.token_cache = []
        self.print_len = 0

    def put(self, value) -> Optional[str]:
        """Add one token id and return the text that became safe to emit.

        Returns ``None`` when the call is consumed as the skipped prompt
        (``skip_prompt``); otherwise a possibly empty string.

        Raises:
            ValueError: if ``value`` is not an ``int``.
        """
        if not isinstance(value, int):
            raise ValueError("TextStreamer only supports batch size 1, and int type input")
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return None

        self.token_cache.append(value)
        decoded = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)

        if decoded.endswith("\n"):
            # Line finished: hand out everything pending and start a fresh cache.
            fragment = decoded[self.print_len:]
            self.reset()
            return fragment

        if decoded and self._is_chinese_char(ord(decoded[-1])):
            cut = len(decoded)
        else:
            # Hold back the trailing word: it may still change with the next token.
            cut = decoded.rfind(" ") + 1
        fragment = decoded[self.print_len:cut]
        self.print_len += len(fragment)
        return fragment

    def end(self) -> Optional[str]:
        """Return whatever text is still pending and reset the streaming state."""
        remainder = ""
        if self.token_cache:
            decoded = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
            remainder = decoded[self.print_len:]
            self.reset()
        self.next_tokens_are_prompt = True
        return remainder

    def _is_chinese_char(self, cp):
        """Return True if code point ``cp`` lies in one of the ``_CJK_RANGES`` blocks."""
        return any(low <= cp <= high for low, high in self._CJK_RANGES)
class TransformersThreadContext(ThreadContext):
    """Thread context that exposes its messages as plain ``{"role", "content"}`` dicts."""

    def get_local_messages(self):
        """Return the thread's messages in chat-template form.

        Each entry holds the message role's ``.value`` and the message's text content.
        """
        return [
            {"role": message.role.value, "content": message.get_text_content()}
            for message in self.messages
        ]
class TransformersInterface(BackendInterfaceBase):
    """Chat-completion backend built on a Hugging Face causal LM.

    Keeps the token history of the most recent thread in ``generated_ids`` /
    ``seq_length``. A follow-up request on the same thread, or a new request
    sharing a prompt prefix, therefore only prefills the new tokens. Generation
    state lives on the instance and decoding writes row 0 of ``generated_ids``,
    so requests are presumably served one at a time — confirm with the caller.
    """
    use_static_cache: bool = True
    model: Any
    tokenizer: AutoTokenizer
    cache: StaticCache
    # Token ids of the tracked conversation (prompt + generated), one row per batch slot.
    generated_ids: torch.Tensor
    # Number of valid positions in ``generated_ids`` (and in the KV cache).
    seq_length: int
    streamer: TextStreamer
    # thread_related
    # Thread id of the previous request; used to detect follow-ups on the same thread.
    last_request_id: Optional[str] = None
    # Every token id sampled since the last clear.
    # NOTE(review): this is a mutable *class* attribute, shared by all instances unless an
    # instance rebinds it; this class only mutates it in place (add/clear).
    ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
self.cache = StaticCache(
config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=args.device,
dtype=self.model.dtype,
)
# logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
self.streamer = TextStreamer(self.tokenizer)
@property
def current_ids(self):
return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
@property
def active_cache_position(self):
return torch.tensor([self.seq_length - 1], device=self.args.device)
def tokenize_prompt(self, prompt: str):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
return input_ids
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages:
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warning("merge two adjacent user messages")
new_messages[-1]["content"] += '\n' + m["content"]
else:
new_messages.append(m)
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
# input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
# else:
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
if not use_torch_npu:
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
else:
logger.debug(f"new_messages: {new_messages}")
input_ids = self.tokenizer.apply_chat_template(
new_messages, add_generation_prompt=True, return_tensors="pt"
)
if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length]
# We can only hope that the input_ids are the same
unequal_mask = torch.ne(x,y)
unequal_positions = torch.nonzero(unequal_mask)
num_unequal_elements = unequal_mask.sum().item()
logger.warning(f'num_unequal_elements: {num_unequal_elements}')
input_ids = input_ids[:,self.seq_length:]
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
self.generated_ids[0, self.seq_length] = new_tokens
self.seq_length += 1
if use_torch_npu:
self.cache.position[0] = self.seq_length
return self.streamer.put(new_tokens)
@staticmethod
def tf_logits_warper(generation_config):
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None:
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
warpers.append(
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
if temperature is None or temperature == 0:
temperature = self.model.generation_config.temperature
if top_p is None:
top_p = self.model.generation_config.top_p
if top_p == 0:
top_p = 0.0001
if use_torch_npu:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, do_sample=True,
top_p=top_p, temperature=temperature
)
else:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.logits_warper = self.tf_logits_warper(generation_config)
def logits_to_token(self, logits: torch.Tensor):
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
probs = torch.nn.functional.softmax(logits, dim=-1)
sample = True
if sample:
last = torch.multinomial(probs, num_samples=1)
else:
_, last = torch.topk(probs, k=1, dim=-1)
last = last.item()
self.ever_generated_ids.add(last)
return last
def decode_one_tokens(self):
if self.use_static_cache:
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(self.current_ids, return_dict=False)[0]
logits = logits[0, -1, :]
return self.logits_to_token(logits)
    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, max_completion_tokens: Optional[int] = None):
        """Run the prompt through the model and sample the first completion token.

        Generator that yields exactly once: the streamer text returned by
        ``append_new_tokens`` for the sampled token.

        Args:
            input_ids: prompt token ids. They are flattened for prefix matching and
                written into ``generated_ids[:, cache_position]``, so a single row is
                assumed — TODO confirm with callers.
            is_new: False when this request continues the previous thread; the ids are
                then appended after ``self.seq_length``. When True, the longest common
                prefix with the stored ``generated_ids`` is kept and only the remaining
                ids are prefilled.
            temperature: sampling temperature forwarded to ``prepare_logits_wrapper``.
            top_p: nucleus-sampling threshold forwarded to ``prepare_logits_wrapper``.
            max_tokens: completion budget; overrides ``max_completion_tokens`` when set.
            max_completion_tokens: completion budget used when ``max_tokens`` is None.
                Either value is capped at ``self.args.max_new_tokens``.

        Side effects:
            Updates ``generated_ids``, ``seq_length``, the KV cache, ``inputs`` /
            ``logits_warper`` (via ``prepare_logits_wrapper``) and
            ``self.max_new_tokens``, the decode budget that ``generate`` reads.
        """
        input_ids_length = input_ids.shape[-1]
        logger.debug(f"input_ids: {input_ids.shape}")
        if max_tokens is not None:
            max_completion_tokens = max_tokens
        # NOTE(review): if a float budget is passed, max_new_tokens becomes a float and the
        # torch.zeros sizes below would likely be rejected.
        if max_completion_tokens is None:
            max_new_tokens = self.args.max_new_tokens
        else:
            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
        if is_new:
            self.ever_generated_ids.clear()
            same_prefix = 0
            flat_input_ids = input_ids.flatten()
            if getattr(self, 'generated_ids', None) is None:
                # First request: allocate the history buffer. seq_length = 1 makes the
                # prefix loop below run zero iterations.
                self.generated_ids = torch.zeros(
                    self.args.batch_size,
                    input_ids.shape[-1] + max_new_tokens + 1,
                    dtype=torch.int,
                    device=self.args.device,
                )
                self.seq_length = 1
            flat_prev_ids = self.generated_ids.flatten()
            # Longest common prefix with the stored history, capped at
            # min(seq_length, prompt length) - 1 so at least one prompt token is always
            # prefilled (its logits are used to sample the next token).
            # NOTE(review): element-wise tensor comparison in a Python loop; on an
            # accelerator each `if` likely forces a host sync — consider vectorizing.
            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
                if flat_input_ids[i] == flat_prev_ids[i]:
                    same_prefix += 1
                else:
                    break
            logger.debug(f"same prefix len: {same_prefix}")
            # NOTE(review): `remove_suffix` is not defined in this file; presumably the cache
            # type in use provides it — confirm for the StaticCache created in __init__.
            self.cache.remove_suffix(same_prefix)
            self.seq_length = same_prefix
            self.generated_ids = self.generated_ids[..., :same_prefix]
            input_ids = input_ids[..., same_prefix:]
            input_ids_length = input_ids.shape[-1]
        # (Also cleared at the top of the is_new branch.)
        self.ever_generated_ids.clear()
        self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")
        logger.debug(f"generate_ids: {self.generated_ids.shape}")
        former_seq_length = self.seq_length
        self.seq_length += input_ids_length
        # Grow the history buffer so it can hold the prompt plus the completion budget.
        expected_length = self.seq_length + max_new_tokens + 1
        delta_length = expected_length - self.generated_ids.shape[-1]
        if delta_length > 0:
            new_generate_ids = torch.zeros(
                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
            )
            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
        if use_torch_npu:
            device = self.args.device
        else:
            device = input_ids.device
        # Subclasses look up embeddings from CPU ids (their embed_tokens presumably lives
        # there — confirm); the embeddings are then moved to `device`.
        if not (type(self) is TransformersInterface):
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
            )[0]
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
        next_token = self.logits_to_token(logits[0, -1, :])
        # Remaining decode budget for generate(): bounded by the request budget and by
        # the free cache length, minus the token sampled here.
        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
        yield self.append_new_tokens(next_token)
@torch.no_grad
def generate(self):
logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
if(self.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end(), "length"
return
self.profiler.set_counter("decode", 0)
for i in range(1, self.max_new_tokens):
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
yield self.streamer.end(), None
yield "", "stop"
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token), None
else: # for's else, if output get max new tokens
yield self.streamer.end(), None
yield "", "length"
if use_torch_npu and self.args.use_cuda_graph:
utils._USE_NPU_GRAPH = False
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.destroy()
def check_is_new(self, thread_id: str):
if not self.use_static_cache:
return True
if self.last_request_id is None:
self.last_request_id = thread_id
return True
else:
if self.last_request_id == thread_id:
return False
else:
self.last_request_id = thread_id
return True
async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
rank = torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if tp_size == world_size and tp_size > 1:
torch.distributed.barrier()
input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
dist.all_gather(all_input_sizes, input_size)
max_input_size = max([size.item() for size in all_input_sizes])
padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
dist.all_gather(all_padded_inputs, padded_input_ids)
original_size = all_input_sizes[0].item()
input_ids = all_padded_inputs[0][:, :original_size]
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
if tp_size == world_size and rank != 0:
pass
else:
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t, finish_reason in self.generate():
if t is not None:
if tp_size == world_size and rank != 0:
pass
else:
print(t, end="",flush=True)
yield t, finish_reason
if tp_size == world_size and rank != 0:
pass
else:
self.profiler.pause_timer("decode")
self.report_last_time_performance()
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
if use_torch_npu:
async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
yield tok
return
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t, finish_reason in self.generate():
if t is not None:
print(t, end="",flush=True)
yield t, finish_reason
print("")
self.profiler.pause_timer("decode")
self.report_last_time_performance()