Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-08 05:29:29 +00:00)
Implement multi-batch support for the DeepSeek v2, v3, and r1 models when backend_type is configured as ktransformers.
parent 890b0f1622
commit a6ab9e349c
6 changed files with 383 additions and 52 deletions
@@ -58,7 +58,11 @@ class StaticCache(transformers.StaticCache):
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
             self.page_size = 64
             self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
-            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
+            if multi_batch_enabled:
+                latent_shape = (max_batch_size, self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            else:
+                latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
             self.kv_lora_rank = config.kv_lora_rank
             self.qk_rope_head_dim = config.qk_rope_head_dim
             # TODO: support real page table
@@ -143,8 +147,14 @@ class StaticCache(transformers.StaticCache):
             page_idx = cache_position // self.page_size
             page_offset = cache_position % self.page_size
             # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
-            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
+            if multi_batch_enabled:
+                batch_size = key_states.size(0)
+                k_out[:batch_size, page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[:batch_size, page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            else:
+                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
             return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
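How the batched cache layout above is addressed during an update, in a toy-sized sketch (dimensions are illustrative, not the real DeepSeek config): the two advanced indices page_idx and page_offset are adjacent, so the assignment target keeps the batch dimension in front and lines up with key_states of shape [batch_size, q_len, 1, kv_lora_rank].

import torch

# Toy sizes for illustration only (not the real model config).
batch_size, max_pages, page_size, kv_lora_rank, rope_dim = 2, 4, 8, 16, 4

# Batched paged latent cache, as in the multi_batch_enabled branch above.
k_cache = torch.zeros(batch_size, max_pages, page_size, 1, kv_lora_rank + rope_dim)

# Three new positions 6, 7, 8 cross a page boundary (page_size = 8).
cache_position = torch.tensor([6, 7, 8])
page_idx = cache_position // page_size    # tensor([0, 0, 1])
page_offset = cache_position % page_size  # tensor([6, 7, 0])

key_states = torch.randn(batch_size, 3, 1, kv_lora_rank)  # compressed kv latents
value_states = torch.randn(batch_size, 3, 1, rope_dim)    # rope part stored alongside

# Adjacent advanced indices stay in place, so the left-hand side has shape
# [batch_size, 3, 1, kv_lora_rank] and matches key_states directly.
k_cache[:batch_size, page_idx, page_offset, :, :kv_lora_rank] = key_states
k_cache[:batch_size, page_idx, page_offset, :, kv_lora_rank:] = value_states

# Position 8 of request 0 landed at page 1, offset 0.
print(k_cache[0, 1, 0, 0, :kv_lora_rank].equal(key_states[0, 2, 0]))  # True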
@@ -693,6 +693,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
         if torch.xpu.is_available():
             return self.forward_xpu(
                 hidden_states,
@@ -707,7 +708,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         elif (os.name == 'nt'
               or get_compute_capability() < 8
               or hidden_states.device.type == 'cpu'
-              or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+              or device_manager.gpu_vendor != GPUVendor.NVIDIA
+              or multi_batch_enabled):
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -670,6 +670,7 @@ class KLinearMarlin(KLinearBase):
             padding_input[:,:self.orin_in_features] = x
             x = padding_input
         marlin_s = self.marlin_s.to(x.dtype)
+        x = x.contiguous()
         x = KTransformersOps.gptq_marlin_gemm(
             x,
             self.marlin_q_w,
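The one-line change above inserts x = x.contiguous() right before the Marlin GEMM call. A likely motivation, not spelled out in the commit: batched decode tends to hand the linear layer strided views, and custom CUDA GEMM kernels generally expect a densely packed row-major activation. A small illustration of the difference:

import torch

x = torch.randn(4, 128, 7168)     # [batch, seq, hidden], toy sizes
last_step = x[:, -1, :]           # a strided view over the batch
print(last_step.is_contiguous())  # False: the strides still skip over the seq dimension

packed = last_step.contiguous()   # copies into a dense row-major buffer
print(packed.is_contiguous())     # True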
@@ -669,10 +669,12 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
             if (os.name == 'nt'
                 or get_compute_capability() < 8
                 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+                or device_manager.gpu_vendor != GPUVendor.NVIDIA
+                or multi_batch_enabled):
                 # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
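Both attention-related hunks above extend the same dispatch rule: whenever multi_batch_enabled is set, the model is routed to the mask-based forward_windows path instead of the single-sequence flash-attention path. The predicate, pulled out into a standalone sketch (the helper and its arguments are illustrative, not the project's API):

import os

def use_masked_path(compute_capability: int, device_type: str,
                    gpu_is_nvidia: bool, multi_batch_enabled: bool) -> bool:
    # Mirrors the condition in the two hunks above.
    return (os.name == 'nt'
            or compute_capability < 8
            or device_type == 'cpu'
            or not gpu_is_nvidia
            or multi_batch_enabled)

# On a Linux box with an Ampere-class NVIDIA GPU:
print(use_masked_path(9, 'cuda', True, multi_batch_enabled=True))   # True: batching forces the masked path
print(use_masked_path(9, 'cuda', True, multi_batch_enabled=False))  # False: a single request keeps flash attention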
@@ -7,8 +7,8 @@ from ktransformers.server.backend.interfaces.transformers import (
     ConfigArgs,
     TransformersThreadContext,
     default_args,
-    TextStreamer,
 )
+from ktransformers.server.config.config import Config
 from ktransformers.server.config.log import logger
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_cache import StaticCache
@@ -18,12 +18,115 @@ from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 from ktransformers.server.schemas.endpoints.chat import RawUsage
+from torch.nn.attention import SDPBackend
 warm_uped = False
+multi_batch_enabled = False
 
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
 
 
+class MultiBatchTextStreamer:
+
+    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.decode_kwargs = decode_kwargs
+
+        # variables used in the streaming process for each batch
+        self.token_caches = {}            # {batch_index: [tokens]}
+        self.print_lens = {}              # {batch_index: print_len}
+        self.next_tokens_are_prompt = {}  # {batch_index: bool}
+
+    def reset(self, batch_index: int = 0):
+        self.token_caches[batch_index] = []
+        self.print_lens[batch_index] = 0
+        self.next_tokens_are_prompt[batch_index] = True
+
+    def reset_all(self):
+        self.token_caches.clear()
+        self.print_lens.clear()
+        self.next_tokens_are_prompt.clear()
+
+    def put(self, value, batch_index: int = 0) -> Optional[str]:
+        """
+        Receives tokens for a specific batch, decodes them, and returns printable text.
+        """
+        if not isinstance(value, int):
+            raise ValueError("MultiBatchTextStreamer only supports int type input")
+
+        # Initialize batch if not exists
+        if batch_index not in self.token_caches:
+            self.reset(batch_index)
+
+        if self.skip_prompt and self.next_tokens_are_prompt[batch_index]:
+            self.next_tokens_are_prompt[batch_index] = False
+            return None
+
+        # Add the new token to the cache and decodes the entire thing.
+        self.token_caches[batch_index].append(value)
+        text = self.tokenizer.decode(self.token_caches[batch_index], skip_special_tokens=True, **self.decode_kwargs)
+
+        # After the symbol for a new line, we flush the cache.
+        if text.endswith("\n"):
+            printable_text = text[self.print_lens[batch_index] :]
+            self.reset(batch_index)
+        # If the last token is a CJK character, we print the characters.
+        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+            printable_text = text[self.print_lens[batch_index] :]
+            self.print_lens[batch_index] += len(printable_text)
+        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
+        # which may change with the subsequent token -- there are probably smarter ways to do this!)
+        else:
+            printable_text = text[self.print_lens[batch_index] : text.rfind(" ") + 1]
+            self.print_lens[batch_index] += len(printable_text)
+        return printable_text
+
+    def end(self, batch_index: int = 0) -> Optional[str]:
+        """Flushes any remaining cache for a specific batch and returns printable text."""
+        if batch_index not in self.token_caches:
+            return ""
+
+        # Flush the cache, if it exists
+        if len(self.token_caches[batch_index]) > 0:
+            text = self.tokenizer.decode(self.token_caches[batch_index], skip_special_tokens=True, **self.decode_kwargs)
+            printable_text = text[self.print_lens[batch_index] :]
+            self.reset(batch_index)
+        else:
+            printable_text = ""
+
+        self.next_tokens_are_prompt[batch_index] = True
+        return printable_text
+
+    def end_all(self) -> List[Optional[str]]:
+        """Flushes all batches and returns a list of printable texts."""
+        results = []
+        for batch_index in sorted(self.token_caches.keys()):
+            results.append(self.end(batch_index))
+        return results[0]
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like the all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)  #
+            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+        ):  #
+            return True
+
+        return False
+
+
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
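Usage sketch for the MultiBatchTextStreamer added above: each request keeps its own token cache and print offset, addressed by batch_index. This assumes the class is importable from the interface module and that a HuggingFace tokenizer can be loaded; the model id is only an example.

from transformers import AutoTokenizer
from ktransformers.server.backend.interfaces.ktransformers import MultiBatchTextStreamer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite")  # example model
streamer = MultiBatchTextStreamer(tokenizer)

for batch_index, text in enumerate(["Hello world", "你好,世界"]):
    for token_id in tokenizer.encode(text, add_special_tokens=False):
        piece = streamer.put(token_id, batch_index)  # printable text, or None/"" if nothing is ready yet
        if piece:
            print(batch_index, piece)
    print(batch_index, streamer.end(batch_index))    # flush whatever is still buffered for this stream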
@@ -40,7 +143,7 @@ class KTransformersInterface(TransformersInterface):
             top_p=args.top_p,
             do_sample=True
         )
-
+        self.tokenizer.pad_token_id = 0
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -64,6 +167,9 @@ class KTransformersInterface(TransformersInterface):
         self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
+        if args.batch_size > 1:
+            global multi_batch_enabled
+            multi_batch_enabled = True
         self.cache = StaticCache(
             config=self.model.config,
             max_batch_size=args.batch_size,
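The hunk above flips the module-level flag before self.cache = StaticCache(...) runs. That ordering matters because the cache and the operators re-import multi_batch_enabled inside their function bodies, so they read whatever value the module attribute holds at call time. A self-contained toy showing the same late-binding behaviour (module and sizes are invented for the demo):

import sys, types

backend = types.ModuleType("toy_backend")
backend.multi_batch_enabled = False
sys.modules["toy_backend"] = backend

def build_latent_shape(max_batch_size, max_pages, page_size, latent_dim):
    from toy_backend import multi_batch_enabled  # read at call time, like the hunks above
    if multi_batch_enabled:
        return (max_batch_size, max_pages, page_size, 1, latent_dim)
    return (max_pages, page_size, 1, latent_dim)

print(build_latent_shape(4, 16, 64, 576))  # (16, 64, 1, 576): unbatched layout
backend.multi_batch_enabled = True
print(build_latent_shape(4, 16, 64, 576))  # (4, 16, 64, 1, 576): batched layout, no re-import needed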
@@ -75,9 +181,14 @@ class KTransformersInterface(TransformersInterface):
 
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
-        self.streamer = TextStreamer(self.tokenizer)
+        self.streamer = MultiBatchTextStreamer(self.tokenizer)
 
         self._infer_lock = asyncio.Lock()
+        self._inference_queue = asyncio.Queue()
+        self._batch_worker_task = None
+
+    def append_new_tokens(self, new_tokens: int, batch_idx: int) -> Optional[str]:
+        self.generated_ids[batch_idx, self.seq_length] = new_tokens
+        return self.streamer.put(new_tokens, batch_idx)
 
     def decode_one_tokens(self):
         global warm_uped
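Bookkeeping sketch for the new append_new_tokens above: generated_ids holds one row per request, and the method writes the sampled token into row batch_idx at the shared seq_length column before handing it to the per-request streamer. Toy version with plain tensors:

import torch

batch_size, max_len, seq_length = 2, 8, 3
generated_ids = torch.zeros(batch_size, max_len, dtype=torch.int)

def append_new_tokens(new_token: int, batch_idx: int) -> None:
    generated_ids[batch_idx, seq_length] = new_token  # same column for every request in this step

append_new_tokens(101, batch_idx=0)
append_new_tokens(202, batch_idx=1)
print(generated_ids[:, seq_length].tolist())  # [101, 202]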
@@ -98,6 +209,7 @@ class KTransformersInterface(TransformersInterface):
                 main_device=torch_device,
                 return_dict=False,
                 use_cache=True,
+                attention_mask=self.attention_mask,
             )
 
         if hasattr(self, "cuda_graph_runner"):
@@ -106,8 +218,12 @@ class KTransformersInterface(TransformersInterface):
             )
             self.cache.change_seq_length(1)
             torch.cuda.synchronize()
-            logits = logits[0, -1, :]
-            return self.logits_to_token(logits)
+            tokens=[]
+            for batch_idx in range(logits.size(0)):
+                logit = logits[batch_idx, -1, :]  # [batch_size, vocab_size]
+                tokens.append(self.logits_to_token(logit))
+            self.update_mask(tokens)
+            return tokens
 
         if self.args.use_cuda_graph:
             warm_uped = True
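Shape sketch for the batched decode step above: the model now returns logits of shape [batch_size, seq_len, vocab_size] and one token is drawn per row (the real code samples through logits_to_token; argmax stands in for it here):

import torch

batch_size, seq_len, vocab_size = 3, 1, 32
logits = torch.randn(batch_size, seq_len, vocab_size)

tokens = []
for batch_idx in range(logits.size(0)):
    logit = logits[batch_idx, -1, :]         # [vocab_size], last position of this request
    tokens.append(int(torch.argmax(logit)))  # greedy stand-in for logits_to_token

print(tokens)  # three next-token ids, one per request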
@@ -119,14 +235,17 @@ class KTransformersInterface(TransformersInterface):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
+                attention_mask=self.attention_mask,
             )[0]
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
-        logits = logits[0, -1, :]
 
-        return self.logits_to_token(logits)
+        tokens=[]
+        for batch_idx in range(logits.size(0)):
+            logit = logits[batch_idx, -1, :]  # [batch_size, vocab_size]
+            tokens.append(self.logits_to_token(logit))
+        self.update_mask(tokens)
+        return tokens
 
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
@@ -147,25 +266,27 @@ class KTransformersInterface(TransformersInterface):
 
         if is_new:
             self.ever_generated_ids.clear()
-            same_prefix = 0
-            flat_input_ids = input_ids.flatten()
 
-            if getattr(self, 'generated_ids', None) is None:
-                self.generated_ids = torch.zeros(
-                    self.args.batch_size,
-                    input_ids.shape[-1] + max_new_tokens + 1,
-                    dtype=torch.int,
-                    device=self.args.device,
-                )
-                self.seq_length = 1
+            self.generated_ids = torch.zeros(
+                input_ids.size(0),
+                input_ids.shape[-1] + max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
+            )
+            self.seq_length = 1
 
-            flat_prev_ids = self.generated_ids.flatten()
-            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
-                if flat_input_ids[i] == flat_prev_ids[i]:
-                    same_prefix += 1
-                else:
-                    break
+            same_prefix = self.seq_length
+            for i in range(input_ids.size(0)):
+                cur_same_prefix = 0
+                flat_input_ids = input_ids[i].flatten()
+                flat_prev_ids = self.generated_ids[i].flatten()
+                for j in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                    if flat_input_ids[j] == flat_prev_ids[j]:
+                        cur_same_prefix += 1
+                    else:
+                        break
+                same_prefix = min(same_prefix, cur_same_prefix)
 
             logger.debug(f"same prefix len: {same_prefix}")
             self.cache.remove_suffix(same_prefix)
             self.seq_length = same_prefix
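The prefix-reuse rule rewritten above, reduced to plain Python: with several requests sharing one cache, only the prefix common to every request (measured against the previously generated ids) can be kept, so the per-request prefix lengths are folded with min():

def shared_prefix_len(prev_ids_per_request, new_ids_per_request, seq_length):
    same_prefix = seq_length
    for prev_ids, new_ids in zip(prev_ids_per_request, new_ids_per_request):
        cur = 0
        for j in range(min(seq_length, len(new_ids)) - 1):
            if new_ids[j] == prev_ids[j]:
                cur += 1
            else:
                break
        same_prefix = min(same_prefix, cur)
    return same_prefix

prev = [[1, 2, 3, 4, 0, 0], [1, 2, 9, 9, 0, 0]]
new  = [[1, 2, 3, 4, 5],    [1, 2, 9, 7, 8]]
print(shared_prefix_len(prev, new, seq_length=5))  # 3: request 1 diverges at its fourth token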
@@ -174,7 +295,7 @@ class KTransformersInterface(TransformersInterface):
         input_ids_length = input_ids.shape[-1]
 
         self.ever_generated_ids.clear()
-        self.profiler.set_counter("prefill", input_ids_length)
+        self.profiler.set_counter("prefill", input_ids.numel())
         logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
 
@@ -184,7 +305,7 @@ class KTransformersInterface(TransformersInterface):
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
-                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+                input_ids.size(0), delta_length, dtype=torch.int, device=self.args.device
             )
             self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
         else:
@@ -210,6 +331,7 @@ class KTransformersInterface(TransformersInterface):
                     past_key_values=self.cache,
                     return_dict=False,
                     use_cache=True,
+                    attention_mask=self.attention_mask,
                 )[0]
             else:
                 logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
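Where the attention_mask passed above comes from (it is built in batch_inference further down): prompts of different lengths are right-padded to a common length, and the mask marks real tokens with 1 and padding with 0. Toy illustration with a pad id of 0:

import torch

pad_token_id = 0
prompts = [torch.tensor([[11, 12, 13, 14]]), torch.tensor([[21, 22]])]

max_length = max(p.size(1) for p in prompts)
padded = [torch.cat([p, torch.full((1, max_length - p.size(1)), pad_token_id)], dim=1)
          if p.size(1) < max_length else p
          for p in prompts]

combined_input_ids = torch.cat(padded, dim=0)                # [batch_size, seq_len]
attention_mask = (combined_input_ids != pad_token_id).int()  # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(attention_mask)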
@@ -227,25 +349,218 @@ class KTransformersInterface(TransformersInterface):
         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
-        next_token = self.logits_to_token(logits[0, -1, :])
         self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
-        yield self.append_new_tokens(next_token)
+        next_tokens=[]
+        for batch_idx in range(input_ids.size(0)):
+            next_token = self.logits_to_token(logits[batch_idx, -1, :])
+            yield self.append_new_tokens(next_token, batch_idx), batch_idx
+            next_tokens.append(next_token)
+        self.seq_length += 1
+        self.update_mask(next_tokens)
+
+    def update_mask(self, new_tokens):
+        batch_size, seq_length = self.attention_mask.shape
+
+        new_tokens_tensor = torch.tensor(new_tokens, device=self.attention_mask.device)
+        new_mask_col = torch.ones(batch_size, 1, device=self.attention_mask.device)
+
+        if self.tokenizer.eos_token_id is not None:
+            eos_mask = (new_tokens_tensor == self.tokenizer.eos_token_id)
+            new_mask_col[eos_mask] = 0
+
+        if self.tokenizer.pad_token_id is not None:
+            pad_mask = (new_tokens_tensor == self.tokenizer.pad_token_id)
+            new_mask_col[pad_mask] = 0
+
+        self.attention_mask = torch.cat([self.attention_mask, new_mask_col], dim=1)
 
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
+    @torch.no_grad
+    def generate(self, request_contexts: list = []):
+        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+        if(self.max_new_tokens <= 0):
+            logger.warning("max_new_tokens is less than 0")
+            yield self.streamer.end_all(), "length"
+            return
+        self.profiler.set_counter("decode", 0)
+
+        for i in range(1, self.max_new_tokens):
+            if all(context['is_completed'] for context in request_contexts):
+                break
+            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+                if flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
+                        num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                        head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                        sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                next_tokens = self.decode_one_tokens()
+                for batch_idx in range(len(next_tokens)):
+                    if request_contexts[batch_idx]['is_completed'] == True:
+                        continue
+                    next_token = next_tokens[batch_idx]
+                    self.profiler.inc("decode")
+                    if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                        yield self.streamer.end(batch_idx), None, batch_idx
+                        yield "", "stop", batch_idx
+                        # assert self.args.batch_size == 1
+                        request_contexts[batch_idx]['is_completed'] = True
+                        continue
+                    yield self.append_new_tokens(next_token, batch_idx), None, batch_idx
+            self.seq_length += 1
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end_all(), None, 0
+            yield "", "length", 0
+
+    async def _batch_worker(self):
+        while True:
+            batch = []
+            for _ in range(self.args.batch_size):
+                try:
+                    item = await asyncio.wait_for(self._inference_queue.get(), timeout=0.001)
+                    batch.append(item)
+                except asyncio.TimeoutError:
+                    logger.debug("Timeout waiting for a single request")
+                    break
+
+            if not batch:
+                await asyncio.sleep(0.001)
+                continue
+            logger.info(f"Collected {len(batch)} requests, starting to process batch")
+
+            batch_data = {
+                'messages': [item.get('local_messages', []) for item in batch],
+                'thread_ids': [item.get('thread_id', '') for item in batch],
+                'temperatures': [item.get('temperature', None) for item in batch],
+                'top_ps': [item.get('top_p', None) for item in batch],
+                'max_tokens': [item.get('max_tokens', None) for item in batch],
+                'max_completion_tokens': [item.get('max_completion_tokens', None) for item in batch]
+            }
+
+            try:
+                async def process_batch():
+                    async for token, finish_reason, index in self.batch_inference(
+                        batch_data['messages'],
+                        batch_data['thread_ids'],
+                        batch_data['temperatures'],
+                        batch_data['top_ps'],
+                        batch_data['max_tokens'],
+                        batch_data['max_completion_tokens']
+                    ):
+                        await batch[index]['result_queue'].put((token, finish_reason))
+                await process_batch()
+            except Exception as e:
+                logger.exception(f"Error in batch inference: {str(e)}")
+                for item in batch:
+                    await item['result_queue'].put(("ERROR", str(e)))
+            finally:
+                for item in batch:
+                    await item['result_queue'].put((None, None))
+
     async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
-        async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
-                yield v
-
-            # return this inference raw usage
-            yield RawUsage(
-                tokenize_time = self.profiler.get_timer_sec('tokenize'),
-                prefill_time = self.profiler.get_timer_sec('prefill'),
-                decode_time = self.profiler.get_timer_sec('decode'),
-                prefill_count = self.profiler.get_counter('prefill'),
-                decode_count = self.profiler.get_counter('decode'),
-            )
+        result_queue = asyncio.Queue()
+        await self._inference_queue.put({
+            'local_messages': local_messages,
+            'thread_id': thread_id,
+            'temperature': temperature,
+            'top_p': top_p,
+            'max_tokens': max_tokens,
+            'max_completion_tokens': max_completion_tokens,
+            'result_queue': result_queue
+        })
+
+        if self._batch_worker_task is None:
+            self._batch_worker_task = asyncio.create_task(self._batch_worker())
+        while True:
+            token, finish_reason = await result_queue.get()
+            if token is None:
+                break
+            yield token, finish_reason
+        yield RawUsage(
+            tokenize_time = self.profiler.get_timer_sec('tokenize'),
+            prefill_time = self.profiler.get_timer_sec('prefill'),
+            decode_time = self.profiler.get_timer_sec('decode'),
+            prefill_count = self.profiler.get_counter('prefill'),
+            decode_count = self.profiler.get_counter('decode'),
+        )
+
+    async def batch_inference(self, batch_messages: List[List], thread_ids: List[str], temperatures: List[Optional[float]], top_ps: List[Optional[float]], max_tokens_list: List[Optional[float]], max_completion_tokens_list: List[Optional[float]]):
+        self.streamer.reset()
+        self.profiler.create_and_start_timer("tokenize")
+        print("SJF batch_messages len is ", len(batch_messages))
+
+        input_ids_list = []
+        for i, messages in enumerate(batch_messages):
+            if isinstance(messages, List):
+                input_ids = self.format_and_tokenize_input_ids(thread_ids[i], messages)
+            elif isinstance(messages, str):
+                input_ids = self.tokenize_prompt(messages)
+            else:
+                raise ValueError("local_messages should be List or str")
+            input_ids_list.append(input_ids)
+
+        max_length = max(ids.size(1) for ids in input_ids_list)
+        padded_input_ids = []
+        for ids in input_ids_list:
+            padding_length = max_length - ids.size(1)
+            if padding_length > 0:
+                padded_ids = torch.cat([ids, torch.full((1, padding_length), self.tokenizer.pad_token_id, device=self.args.device)], dim=1)
+            else:
+                padded_ids = ids
+            padded_input_ids.append(padded_ids)
+
+        combined_input_ids = torch.cat(padded_input_ids, dim=0)  # [batch_size, seq_len]
+        self.attention_mask = (combined_input_ids != self.tokenizer.pad_token_id).int()
+
+        if Config().user_force_think:
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
+            input_ids = torch.cat(
+                [input_ids, token_thinks], dim=1
+            )
+
+        self.profiler.pause_timer("tokenize")
+
+        self.profiler.create_and_start_timer("prefill")
+
+        if Config().user_force_think:
+            think = '<think>\n'
+            print(think, end="",flush=True)
+            yield think, None
+
+        for t, batch_idx in self.prefill(
+            combined_input_ids,
+            True,  # is_new
+            temperatures[0] if temperatures else None,
+            top_ps[0] if top_ps else None,
+            max_tokens_list[0] if max_tokens_list else None,
+            max_completion_tokens_list[0] if max_completion_tokens_list else None,
+        ):
+            # output think token after prefill done
+            if t is not None:
+                print(t, end="",flush=True)
+                yield t, None, batch_idx
+        self.profiler.pause_timer("prefill")
+
+        self.profiler.create_and_start_timer("decode")
+        request_contexts = []
+        for i in range(len(batch_messages)):
+            context = {
+                'is_completed': False,
+            }
+            request_contexts.append(context)
+
+        self.profiler.create_and_start_timer("decode")
+
+        for t, finish_reason, batch_idx in self.generate(request_contexts):
+            if t is not None:
+                if multi_batch_enabled:
+                    print(f"Inference result: batch_idx={batch_idx}, token={t}", flush=True)
+                else:
+                    print(t, end="",flush=True)
+                yield t, finish_reason, batch_idx
+        print("")
+        self.profiler.pause_timer("decode")
+        self.report_last_time_performance()
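A stripped-down, standalone sketch of the queueing scheme introduced above: each request owns a result queue, a single worker drains up to batch_size pending requests, feeds them to the batched generator, and routes (token, finish_reason) pairs back by batch index, ending each stream with a (None, None) sentinel. Names mirror the diff, but the generator body is a stand-in so the sketch runs without a model:

import asyncio

BATCH_SIZE = 2

async def batch_inference(messages):
    # Stand-in for the real batch_inference: one "token" per request, then a stop signal.
    for idx, msg in enumerate(messages):
        yield f"echo:{msg}", None, idx
    for idx, _ in enumerate(messages):
        yield "", "stop", idx

async def batch_worker(inference_queue):
    while True:
        batch = []
        for _ in range(BATCH_SIZE):
            try:
                batch.append(await asyncio.wait_for(inference_queue.get(), timeout=0.001))
            except asyncio.TimeoutError:
                break
        if not batch:
            await asyncio.sleep(0.001)
            continue
        async for token, finish_reason, idx in batch_inference([item["msg"] for item in batch]):
            await batch[idx]["result_queue"].put((token, finish_reason))
        for item in batch:
            await item["result_queue"].put((None, None))  # sentinel: this request is finished

async def inference(msg, inference_queue):
    result_queue = asyncio.Queue()
    await inference_queue.put({"msg": msg, "result_queue": result_queue})
    collected = []
    while True:
        token, finish_reason = await result_queue.get()
        if token is None:
            return collected
        collected.append((token, finish_reason))

async def main():
    queue = asyncio.Queue()
    worker = asyncio.create_task(batch_worker(queue))
    print(await asyncio.gather(inference("hi", queue), inference("there", queue)))
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass

asyncio.run(main())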
@@ -141,6 +141,7 @@ class TransformersInterface(BackendInterfaceBase):
     # thread_related
     last_request_id: Optional[str] = None
     ever_generated_ids: Set[int] = set()
+    attention_mask: torch.Tensor
 
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args