From a6ab9e349cc8b4bf08ca5fd14de3ad9f07894e84 Mon Sep 17 00:00:00 2001 From: jiafei96 Date: Wed, 9 Jul 2025 09:09:47 +0000 Subject: [PATCH] Implement multi-batch support for v2, v3, and r1 models with backend_type configured as ktransformers. --- ktransformers/models/custom_cache.py | 16 +- ktransformers/operators/attention.py | 4 +- ktransformers/operators/linear.py | 1 + ktransformers/operators/models.py | 4 +- .../backend/interfaces/ktransformers.py | 409 ++++++++++++++++-- .../server/backend/interfaces/transformers.py | 1 + 6 files changed, 383 insertions(+), 52 deletions(-) diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index 350af73..05ddf65 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -58,7 +58,11 @@ class StaticCache(transformers.StaticCache): # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically self.page_size = 64 self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size - latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) + from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled + if multi_batch_enabled: + latent_shape = (max_batch_size, self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) + else: + latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) self.kv_lora_rank = config.kv_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim # TODO: support real page table @@ -143,8 +147,14 @@ class StaticCache(transformers.StaticCache): page_idx = cache_position // self.page_size page_offset = cache_position % self.page_size # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) - k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states - k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states + from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled + if multi_batch_enabled: + batch_size = key_states.size(0) + k_out[:batch_size, page_idx, page_offset, :, :self.kv_lora_rank] = key_states + k_out[:batch_size, page_idx, page_offset, :, self.kv_lora_rank:] = value_states + else: + k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states + k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states return k_out, self.page_table_list[layer_idx] else: k_out[:, :, cache_position] = key_states diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 9dfdbdc..04968be 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -693,6 +693,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled if torch.xpu.is_available(): return self.forward_xpu( hidden_states, @@ -707,7 +708,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): elif (os.name == 'nt' or get_compute_capability() < 8 or hidden_states.device.type == 'cpu' - or device_manager.gpu_vendor != GPUVendor.NVIDIA): + or device_manager.gpu_vendor != GPUVendor.NVIDIA + or multi_batch_enabled): return self.forward_windows( hidden_states, attention_mask, diff --git a/ktransformers/operators/linear.py 
b/ktransformers/operators/linear.py index 654c9f9..335c6ff 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -670,6 +670,7 @@ class KLinearMarlin(KLinearBase): padding_input[:,:self.orin_in_features] = x x = padding_input marlin_s = self.marlin_s.to(x.dtype) + x = x.contiguous() x = KTransformersOps.gptq_marlin_gemm( x, self.marlin_q_w, diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index e136b57..39dbaf2 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -669,10 +669,12 @@ class KDeepseekV2Model(BaseInjectedModule): if per_layer_prefill_flag: causal_mask = None else: + from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled if (os.name == 'nt' or get_compute_capability() < 8 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values()) - or device_manager.gpu_vendor != GPUVendor.NVIDIA): + or device_manager.gpu_vendor != GPUVendor.NVIDIA + or multi_batch_enabled): # print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index fd2a808..a1e0860 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -7,8 +7,8 @@ from ktransformers.server.backend.interfaces.transformers import ( ConfigArgs, TransformersThreadContext, default_args, - TextStreamer, ) +from ktransformers.server.config.config import Config from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_cache import StaticCache @@ -18,12 +18,115 @@ from ktransformers.util.utils import get_device from typing import Optional from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton from ktransformers.server.schemas.endpoints.chat import RawUsage - +from torch.nn.attention import SDPBackend warm_uped = False +multi_batch_enabled = False class KTransformersThreadContext(TransformersThreadContext): pass +class MultiBatchTextStreamer: + + def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.decode_kwargs = decode_kwargs + + # variables used in the streaming process for each batch + self.token_caches = {} # {batch_index: [tokens]} + self.print_lens = {} # {batch_index: print_len} + self.next_tokens_are_prompt = {} # {batch_index: bool} + + def reset(self, batch_index: int = 0): + self.token_caches[batch_index] = [] + self.print_lens[batch_index] = 0 + self.next_tokens_are_prompt[batch_index] = True + + def reset_all(self): + self.token_caches.clear() + self.print_lens.clear() + self.next_tokens_are_prompt.clear() + + def put(self, value, batch_index: int = 0) -> Optional[str]: + """ + Receives tokens for a specific batch, decodes them, and returns printable text. 
+ """ + if not isinstance(value, int): + raise ValueError("MultiBatchTextStreamer only supports int type input") + + # Initialize batch if not exists + if batch_index not in self.token_caches: + self.reset(batch_index) + + if self.skip_prompt and self.next_tokens_are_prompt[batch_index]: + self.next_tokens_are_prompt[batch_index] = False + return None + + # Add the new token to the cache and decodes the entire thing. + self.token_caches[batch_index].append(value) + text = self.tokenizer.decode(self.token_caches[batch_index], skip_special_tokens=True, **self.decode_kwargs) + + # After the symbol for a new line, we flush the cache. + if text.endswith("\n"): + printable_text = text[self.print_lens[batch_index] :] + self.reset(batch_index) + # If the last token is a CJK character, we print the characters. + elif len(text) > 0 and self._is_chinese_char(ord(text[-1])): + printable_text = text[self.print_lens[batch_index] :] + self.print_lens[batch_index] += len(printable_text) + # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, + # which may change with the subsequent token -- there are probably smarter ways to do this!) + else: + printable_text = text[self.print_lens[batch_index] : text.rfind(" ") + 1] + self.print_lens[batch_index] += len(printable_text) + return printable_text + + def end(self, batch_index: int = 0) -> Optional[str]: + """Flushes any remaining cache for a specific batch and returns printable text.""" + if batch_index not in self.token_caches: + return "" + + # Flush the cache, if it exists + if len(self.token_caches[batch_index]) > 0: + text = self.tokenizer.decode(self.token_caches[batch_index], skip_special_tokens=True, **self.decode_kwargs) + printable_text = text[self.print_lens[batch_index] :] + self.reset(batch_index) + else: + printable_text = "" + + self.next_tokens_are_prompt[batch_index] = True + return printable_text + + def end_all(self) -> List[Optional[str]]: + """Flushes all batches and returns a list of printable texts.""" + results = [] + for batch_index in sorted(self.token_caches.keys()): + results.append(self.end(batch_index)) + return results[0] + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False class KTransformersInterface(TransformersInterface): def __init__(self, args: ConfigArgs = default_args): @@ -40,7 +143,7 @@ class KTransformersInterface(TransformersInterface): top_p=args.top_p, do_sample=True ) - + self.tokenizer.pad_token_id = 0 torch.set_default_dtype(config.torch_dtype) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" @@ -64,6 +167,9 @@ class KTransformersInterface(TransformersInterface): self.model.generation_config = generation_config self.device_map = self.model.gguf_loader.tensor_device_map # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") + if args.batch_size > 1: + global multi_batch_enabled + multi_batch_enabled = True self.cache = StaticCache( config=self.model.config, max_batch_size=args.batch_size, @@ -75,9 +181,14 @@ class KTransformersInterface(TransformersInterface): if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id - self.streamer = TextStreamer(self.tokenizer) - + self.streamer = MultiBatchTextStreamer(self.tokenizer) self._infer_lock = asyncio.Lock() + self._inference_queue = asyncio.Queue() + self._batch_worker_task = None + + def append_new_tokens(self, new_tokens: int, batch_idx: int) -> Optional[str]: + self.generated_ids[batch_idx, self.seq_length] = new_tokens + return self.streamer.put(new_tokens, batch_idx) def decode_one_tokens(self): global warm_uped @@ -98,6 +209,7 @@ class KTransformersInterface(TransformersInterface): main_device=torch_device, return_dict=False, use_cache=True, + attention_mask=self.attention_mask, ) if hasattr(self, "cuda_graph_runner"): @@ -106,8 +218,12 @@ class KTransformersInterface(TransformersInterface): ) self.cache.change_seq_length(1) torch.cuda.synchronize() - logits = logits[0, -1, :] - return self.logits_to_token(logits) + tokens=[] + for batch_idx in range(logits.size(0)): + logit = logits[batch_idx, -1, :] # [batch_size, vocab_size] + tokens.append(self.logits_to_token(logit)) + self.update_mask(tokens) + return tokens if self.args.use_cuda_graph: warm_uped = True @@ -119,14 +235,17 @@ class KTransformersInterface(TransformersInterface): past_key_values=self.cache, return_dict=False, use_cache=True, + attention_mask=self.attention_mask, )[0] else: logits = self.model(self.current_ids, return_dict=False)[0] - logits = logits[0, -1, :] - - return self.logits_to_token(logits) - + tokens=[] + for batch_idx in range(logits.size(0)): + logit = logits[batch_idx, -1, :] # [batch_size, vocab_size] + tokens.append(self.logits_to_token(logit)) + self.update_mask(tokens) + return tokens @torch.no_grad def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): @@ -147,25 +266,27 @@ class KTransformersInterface(TransformersInterface): if is_new: self.ever_generated_ids.clear() - same_prefix = 0 - flat_input_ids = input_ids.flatten() - if getattr(self, 'generated_ids', None) is None: - self.generated_ids = torch.zeros( - self.args.batch_size, - 
input_ids.shape[-1] + max_new_tokens + 1, - dtype=torch.int, - device=self.args.device, - ) - self.seq_length = 1 - - flat_prev_ids = self.generated_ids.flatten() - for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1): - if flat_input_ids[i] == flat_prev_ids[i]: - same_prefix += 1 - else: - break - + self.generated_ids = torch.zeros( + input_ids.size(0), + input_ids.shape[-1] + max_new_tokens + 1, + dtype=torch.int, + device=self.args.device, + ) + self.seq_length = 1 + + same_prefix = self.seq_length + for i in range(input_ids.size(0)): + cur_same_prefix = 0 + flat_input_ids = input_ids[i].flatten() + flat_prev_ids = self.generated_ids[i].flatten() + for j in range(min(self.seq_length, flat_input_ids.shape[0]) - 1): + if flat_input_ids[j] == flat_prev_ids[j]: + cur_same_prefix += 1 + else: + break + same_prefix = min(same_prefix, cur_same_prefix) + logger.debug(f"same prefix len: {same_prefix}") self.cache.remove_suffix(same_prefix) self.seq_length = same_prefix @@ -174,7 +295,7 @@ class KTransformersInterface(TransformersInterface): input_ids_length = input_ids.shape[-1] self.ever_generated_ids.clear() - self.profiler.set_counter("prefill", input_ids_length) + self.profiler.set_counter("prefill", input_ids.numel()) logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}") @@ -184,7 +305,7 @@ class KTransformersInterface(TransformersInterface): delta_length = expected_length - self.generated_ids.shape[-1] if delta_length > 0: new_generate_ids = torch.zeros( - self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device + input_ids.size(0), delta_length, dtype=torch.int, device=self.args.device ) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) else: @@ -210,6 +331,7 @@ class KTransformersInterface(TransformersInterface): past_key_values=self.cache, return_dict=False, use_cache=True, + attention_mask=self.attention_mask, )[0] else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] @@ -227,25 +349,218 @@ class KTransformersInterface(TransformersInterface): if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() self.prepare_logits_wrapper(input_ids, device, temperature, top_p) - next_token = self.logits_to_token(logits[0, -1, :]) - self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 - yield self.append_new_tokens(next_token) + self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 + next_tokens=[] + for batch_idx in range(input_ids.size(0)): + next_token = self.logits_to_token(logits[batch_idx, -1, :]) + yield self.append_new_tokens(next_token, batch_idx), batch_idx + next_tokens.append(next_token) + self.seq_length += 1 + self.update_mask(next_tokens) + + def update_mask(self, new_tokens): + batch_size, seq_length = self.attention_mask.shape + + new_tokens_tensor = torch.tensor(new_tokens, device=self.attention_mask.device) + new_mask_col = torch.ones(batch_size, 1, device=self.attention_mask.device) + + if self.tokenizer.eos_token_id is not None: + eos_mask = (new_tokens_tensor == self.tokenizer.eos_token_id) + new_mask_col[eos_mask] = 0 + if self.tokenizer.pad_token_id is not None: + pad_mask = (new_tokens_tensor == self.tokenizer.pad_token_id) + new_mask_col[pad_mask] = 0 + + self.attention_mask = torch.cat([self.attention_mask, new_mask_col], dim=1) + @property def active_cache_position(self): device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") return 
torch.tensor([self.seq_length - 1], device=device) - + + @torch.no_grad + def generate(self, request_contexts: list = []): + logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}") + if(self.max_new_tokens <= 0): + logger.warning("max_new_tokens is less than 0") + yield self.streamer.end_all(), "length" + return + self.profiler.set_counter("decode", 0) + + for i in range(1, self.max_new_tokens): + if all(context['is_completed'] for context in request_contexts): + break + with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + if flashinfer_enabled: + MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + next_tokens = self.decode_one_tokens() + for batch_idx in range(len(next_tokens)): + if request_contexts[batch_idx]['is_completed'] == True: + continue + next_token = next_tokens[batch_idx] + self.profiler.inc("decode") + if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token): + yield self.streamer.end(batch_idx), None, batch_idx + yield "", "stop", batch_idx + # assert self.args.batch_size == 1 + request_contexts[batch_idx]['is_completed'] = True + continue + yield self.append_new_tokens(next_token, batch_idx), None, batch_idx + self.seq_length += 1 + else: # for's else, if output get max new tokens + yield self.streamer.end_all(), None, 0 + yield "", "length", 0 + + async def _batch_worker(self): + while True: + batch = [] + for _ in range(self.args.batch_size): + try: + item = await asyncio.wait_for(self._inference_queue.get(), timeout=0.001) + batch.append(item) + except asyncio.TimeoutError: + logger.debug("Timeout waiting for a single request") + break + + if not batch: + await asyncio.sleep(0.001) + continue + logger.info(f"Collected {len(batch)} requests, starting to process batch") + + batch_data = { + 'messages': [item.get('local_messages', []) for item in batch], + 'thread_ids': [item.get('thread_id', '') for item in batch], + 'temperatures': [item.get('temperature', None) for item in batch], + 'top_ps': [item.get('top_p', None) for item in batch], + 'max_tokens': [item.get('max_tokens', None) for item in batch], + 'max_completion_tokens': [item.get('max_completion_tokens', None) for item in batch] + } + + try: + async def process_batch(): + async for token, finish_reason, index in self.batch_inference( + batch_data['messages'], + batch_data['thread_ids'], + batch_data['temperatures'], + batch_data['top_ps'], + batch_data['max_tokens'], + batch_data['max_completion_tokens'] + ): + await batch[index]['result_queue'].put((token, finish_reason)) + await process_batch() + except Exception as e: + logger.exception(f"Error in batch inference: {str(e)}") + for item in batch: + await item['result_queue'].put(("ERROR", str(e))) + finally: + for item in batch: + await item['result_queue'].put((None, None)) + async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): - async with self._infer_lock: - async for v in 
super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens): - yield v - - # return this inference raw usage - yield RawUsage( - tokenize_time = self.profiler.get_timer_sec('tokenize'), - prefill_time = self.profiler.get_timer_sec('prefill'), - decode_time = self.profiler.get_timer_sec('decode'), - prefill_count = self.profiler.get_counter('prefill'), - decode_count = self.profiler.get_counter('decode'), - ) \ No newline at end of file + result_queue = asyncio.Queue() + await self._inference_queue.put({ + 'local_messages': local_messages, + 'thread_id': thread_id, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'max_completion_tokens': max_completion_tokens, + 'result_queue': result_queue + }) + + if self._batch_worker_task is None: + self._batch_worker_task = asyncio.create_task(self._batch_worker()) + while True: + token, finish_reason = await result_queue.get() + if token is None: + break + yield token, finish_reason + yield RawUsage( + tokenize_time = self.profiler.get_timer_sec('tokenize'), + prefill_time = self.profiler.get_timer_sec('prefill'), + decode_time = self.profiler.get_timer_sec('decode'), + prefill_count = self.profiler.get_counter('prefill'), + decode_count = self.profiler.get_counter('decode'), + ) + + async def batch_inference(self, batch_messages: List[List], thread_ids: List[str], temperatures: List[Optional[float]], top_ps: List[Optional[float]], max_tokens_list: List[Optional[float]], max_completion_tokens_list: List[Optional[float]]): + self.streamer.reset() + self.profiler.create_and_start_timer("tokenize") + print("SJF batch_messages len is ", len(batch_messages)) + + input_ids_list = [] + for i, messages in enumerate(batch_messages): + if isinstance(messages, List): + input_ids = self.format_and_tokenize_input_ids(thread_ids[i], messages) + elif isinstance(messages, str): + input_ids = self.tokenize_prompt(messages) + else: + raise ValueError("local_messages should be List or str") + input_ids_list.append(input_ids) + + max_length = max(ids.size(1) for ids in input_ids_list) + padded_input_ids = [] + for ids in input_ids_list: + padding_length = max_length - ids.size(1) + if padding_length > 0: + padded_ids = torch.cat([ids, torch.full((1, padding_length), self.tokenizer.pad_token_id, device=self.args.device)], dim=1) + else: + padded_ids = ids + padded_input_ids.append(padded_ids) + + combined_input_ids = torch.cat(padded_input_ids, dim=0) # [batch_size, seq_len] + self.attention_mask = (combined_input_ids != self.tokenizer.pad_token_id).int() + + if Config().user_force_think: + token_thinks = torch.tensor([self.tokenizer.encode("\n",add_special_tokens=False)],device=input_ids.device) + input_ids = torch.cat( + [input_ids, token_thinks], dim=1 + ) + + self.profiler.pause_timer("tokenize") + + self.profiler.create_and_start_timer("prefill") + + if Config().user_force_think: + think = '\n' + print(think, end="",flush=True) + yield think, None + + for t, batch_idx in self.prefill( + combined_input_ids, + True, # is_new + temperatures[0] if temperatures else None, + top_ps[0] if top_ps else None, + max_tokens_list[0] if max_tokens_list else None, + max_completion_tokens_list[0] if max_completion_tokens_list else None, + ): + # output think token after prefill done + if t is not None: + print(t, end="",flush=True) + yield t, None, batch_idx + self.profiler.pause_timer("prefill") + + self.profiler.create_and_start_timer("decode") + request_contexts = [] + for i in range(len(batch_messages)): + 
context = { + 'is_completed': False, + } + request_contexts.append(context) + + for t, finish_reason, batch_idx in self.generate(request_contexts): + if t is not None: + if multi_batch_enabled: + print(f"Inference result: batch_idx={batch_idx}, token={t}", flush=True) + else: + print(t, end="",flush=True) + yield t, finish_reason, batch_idx + print("") + self.profiler.pause_timer("decode") + self.report_last_time_performance() \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 78cb73f..77b25f6 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -141,6 +141,7 @@ class TransformersInterface(BackendInterfaceBase): # thread_related last_request_id: Optional[str] = None ever_generated_ids: Set[int] = set() + attention_mask: torch.Tensor def __init__(self, args: ConfigArgs = default_args): self.args = args
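
The sketches below restate the main mechanisms this patch introduces as standalone Python; any name, size, or value that does not appear in the diff is an assumption made only for illustration.

With multi_batch_enabled set (args.batch_size > 1), StaticCache allocates the paged latent cache with a leading batch dimension, (max_batch_size, max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim), and update() scatters each step's latent K/V into (page_idx, page_offset) for every batch row. A toy sketch of that indexing; the sizes are made up and far smaller than any real DeepSeek config:

import torch

page_size, max_pages = 4, 3                      # toy paging parameters
kv_lora_rank, qk_rope_head_dim = 5, 3            # toy latent dims; latent width = 5 + 3 = 8
batch_size, step_positions = 2, 2

# single-batch layout:  (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
# multi-batch layout:   (batch, max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
k_cache = torch.zeros(batch_size, max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

cache_position = torch.tensor([4, 5])            # absolute positions written this step
page_idx = cache_position // page_size           # -> [1, 1]
page_offset = cache_position % page_size         # -> [0, 1]

key_states = torch.randn(batch_size, step_positions, 1, kv_lora_rank)
value_states = torch.randn(batch_size, step_positions, 1, qk_rope_head_dim)

# same advanced indexing as the multi-batch branch of StaticCache.update()
k_cache[:batch_size, page_idx, page_offset, :, :kv_lora_rank] = key_states
k_cache[:batch_size, page_idx, page_offset, :, kv_lora_rank:] = value_states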
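
In batch_inference(), each request is tokenized separately (chat-formatted list or raw string), every prompt is right-padded with tokenizer.pad_token_id (set to 0 in __init__) up to the longest prompt in the batch, the padded prompts are concatenated along the batch dimension, and the initial attention_mask is derived as (combined_input_ids != pad_token_id).int(). A minimal standalone version of that step; the pad_and_mask helper and the example token ids are made up for the sketch:

import torch
from typing import List, Tuple

def pad_and_mask(prompts: List[torch.Tensor], pad_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad [1, L_i] prompts to a common length and build the initial attention mask."""
    max_len = max(p.size(1) for p in prompts)
    padded = []
    for p in prompts:
        if p.size(1) < max_len:
            fill = torch.full((1, max_len - p.size(1)), pad_id, dtype=p.dtype)
            p = torch.cat([p, fill], dim=1)
        padded.append(p)
    input_ids = torch.cat(padded, dim=0)            # [batch, max_len]
    attention_mask = (input_ids != pad_id).int()    # 0 over the right padding
    return input_ids, attention_mask

ids, mask = pad_and_mask([torch.tensor([[5, 6, 7, 8]]), torch.tensor([[9, 10]])], pad_id=0)
print(ids)    # tensor([[ 5,  6,  7,  8], [ 9, 10,  0,  0]])
print(mask)   # tensor([[1, 1, 1, 1], [1, 1, 0, 0]])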
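
During decode, update_mask() extends that mask by one column per generated step and zeroes the new column for rows whose newest token equals the EOS or PAD id, so those positions are ignored by attention in later steps. A pure-torch sketch of the same bookkeeping; the grow_mask name and the eos_id/pad_id values are illustrative:

import torch
from typing import List, Optional

def grow_mask(attention_mask: torch.Tensor, new_tokens: List[int],
              eos_id: Optional[int], pad_id: Optional[int]) -> torch.Tensor:
    """Append one mask column for the step; zero it where the new token is EOS or PAD."""
    col = torch.ones(attention_mask.size(0), 1, dtype=attention_mask.dtype,
                     device=attention_mask.device)
    toks = torch.tensor(new_tokens, device=attention_mask.device)
    if eos_id is not None:
        col[toks == eos_id] = 0
    if pad_id is not None:
        col[toks == pad_id] = 0
    return torch.cat([attention_mask, col], dim=1)

mask = torch.tensor([[1, 1, 1], [1, 1, 0]])          # batch of 2; second prompt was right-padded
mask = grow_mask(mask, new_tokens=[42, 2], eos_id=2, pad_id=0)
print(mask)   # tensor([[1, 1, 1, 1], [1, 1, 0, 0]])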
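
MultiBatchTextStreamer keeps a separate token cache, print length, and prompt flag per batch index, so interleaved requests can be detokenized incrementally without mixing their partial words; put() takes a single token id plus the batch index and returns the printable piece (None only when the prompt token is being skipped), and end() flushes whatever is still buffered for that request. A usage sketch; the checkpoint name is only a placeholder for whichever model the server actually loads:

from transformers import AutoTokenizer
from ktransformers.server.backend.interfaces.ktransformers import MultiBatchTextStreamer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite")  # placeholder checkpoint
streamer = MultiBatchTextStreamer(tokenizer, skip_prompt=False)

# Two concurrent requests; the decode loop feeds one int per call, tagged with its batch index.
for batch_idx, text in [(0, "Hello from request zero."), (1, "A second, longer request.")]:
    for token_id in tokenizer.encode(text, add_special_tokens=False):
        piece = streamer.put(token_id, batch_idx)
        if piece:
            print(batch_idx, repr(piece))

# Flush whatever each request still has buffered (the trailing partial word, if any).
for batch_idx in (0, 1):
    print(batch_idx, repr(streamer.end(batch_idx)))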
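
On the serving side, inference() no longer runs each request under _infer_lock; it enqueues the request together with a per-request result queue, and a single _batch_worker task drains up to args.batch_size queued items (waiting at most 1 ms per item) before calling batch_inference() and routing each (token, finish_reason) pair back to the queue of the request it belongs to, with (None, None) as the end-of-request sentinel. The self-contained sketch below shows the same queue/worker/sentinel pattern with a word-echo stub in place of the model; BATCH_SIZE, COLLECT_TIMEOUT, and the helper names are assumptions for the demo:

import asyncio
from typing import Any, Dict, List

BATCH_SIZE = 4           # stands in for args.batch_size
COLLECT_TIMEOUT = 0.001  # per-item collection timeout, mirroring _batch_worker

async def batch_worker(request_queue: asyncio.Queue) -> None:
    while True:
        batch: List[Dict[str, Any]] = []
        for _ in range(BATCH_SIZE):
            try:
                batch.append(await asyncio.wait_for(request_queue.get(), timeout=COLLECT_TIMEOUT))
            except asyncio.TimeoutError:
                break
        if not batch:
            await asyncio.sleep(0.001)
            continue
        # Stand-in for batch_inference(): echo each prompt word by word.
        for item in batch:
            for word in item["prompt"].split():
                await item["result_queue"].put(word)
            await item["result_queue"].put(None)        # sentinel: this request is done

async def infer(request_queue: asyncio.Queue, prompt: str) -> List[str]:
    result_queue: asyncio.Queue = asyncio.Queue()
    await request_queue.put({"prompt": prompt, "result_queue": result_queue})
    pieces: List[str] = []
    while (piece := await result_queue.get()) is not None:
        pieces.append(piece)
    return pieces

async def main() -> None:
    request_queue: asyncio.Queue = asyncio.Queue()
    worker = asyncio.create_task(batch_worker(request_queue))
    print(await asyncio.gather(infer(request_queue, "hello multi batch"),
                               infer(request_queue, "a second request")))
    worker.cancel()

asyncio.run(main())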