mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00

✨: refactor local_chat and fix message slice bug in server

parent 43fc7f44a6
commit dd1d8667f3
13 changed files with 549 additions and 405 deletions
@@ -1,6 +1,12 @@
 import torch
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
-from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer
+from ktransformers.server.backend.interfaces.transformers import (
+    TransformersInterface,
+    ConfigArgs,
+    TransformersThreadContext,
+    default_args,
+    TextStreamer,
+)
 from ktransformers.server.config.log import logger
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_cache import StaticCache
@@ -14,71 +20,85 @@ class KTransformersThreadContext(TransformersThreadContext):
 
 
 class KTransformersInterface(TransformersInterface):
-    def __init__(self,args:ConfigArgs= default_args):
+    def __init__(self, args: ConfigArgs = default_args):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device)
-        config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
-            config._attn_implementation="flash_attention_2"
+            config._attn_implementation = "flash_attention_2"
 
         with torch.device("meta"):
-            self.model=custom_models[config.architectures[0]](config)
+            self.model = custom_models[config.architectures[0]](config)
         if default_args.optimize_config_path is None:
             optimize_rule_path = default_optimize_rules[config.architectures[0]]
         else:
             optimize_rule_path = args.optimize_config_path
 
         # print(optimize_config)
 
         gguf_path = args.gguf_path
         if gguf_path is None:
             gguf_path = input(
-                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
+                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
+                " belong to current model):"
             )
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
 
         device_map = self.model.gguf_loader.tensor_device_map
-        logger.info(f'{args.model_name} loaded from {args.model_dir} to {device_map}')
-        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=device_map, dtype=self.model.dtype)
-        logger.info(f'StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}')
+        logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
+        self.cache = StaticCache(
+            config=self.model.config,
+            max_batch_size=args.batch_size,
+            max_cache_len=args.cache_lens,
+            device=device_map,
+            dtype=self.model.dtype,
+        )
+        logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
         self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
 
     def decode_one_tokens(self):
         if not hasattr(self, "cuda_graph_runner"):
             device_map = self.model.gguf_loader.tensor_device_map
-            torch_device = get_device('blk.0.self_attn', device_map)
+            torch_device = get_device("blk.0.self_attn", device_map)
             torch_device = "cuda:0" if torch_device == "cuda" else torch_device
             self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
+            self.cuda_graph_runner.capture(
+                self.model,
+                self.current_ids,
+                self.active_cache_position.unsqueeze(0),
+                self.active_cache_position,
+                self.cache,
+                main_device=torch_device,
+                return_dict=False,
+                use_cache=True,
+            )
 
         if hasattr(self, "cuda_graph_runner"):
-            logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
+            logits = self.cuda_graph_runner(
+                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
+            )
             self.cache.change_seq_length(1)
             torch.cuda.synchronize()
-            logits = logits[0,-1,:]
+            logits = logits[0, -1, :]
             return self.logits_to_token(logits)
 
         if self.use_static_cache:
-            mask = torch.ones((1,self.seq_length)).to(torch_device)
+            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
-                use_cache=True
+                use_cache=True,
             )[0]
         else:
-            logits = self.model(
-                self.current_ids,
-                return_dict=False
-            )[0]
-        logits = logits[0,-1,:]
+            logits = self.model(self.current_ids, return_dict=False)[0]
+        logits = logits[0, -1, :]
 
         return self.logits_to_token(logits)
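The decode_one_tokens fast path above relies on ktransformers' CUDAGraphRunner, which records one decode step into a CUDA graph and then replays it. For readers unfamiliar with that pattern, here is a minimal sketch using plain torch.cuda.CUDAGraph; every name in it (toy_model, static_ids, run) is an illustrative assumption, not part of the ktransformers API:

    import torch

    # Toy stand-in for a decode step. CUDA graphs replay fixed kernels over
    # fixed memory, so inputs must live in static buffers that are updated
    # in place before each replay.
    toy_model = torch.nn.Linear(8, 8).cuda().eval()
    static_ids = torch.zeros(1, 8, device="cuda")

    # Warm up on a side stream, then capture one forward pass into a graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        toy_model(static_ids)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits = toy_model(static_ids)

    def run(new_ids: torch.Tensor) -> torch.Tensor:
        static_ids.copy_(new_ids)  # refresh the captured input buffer
        graph.replay()             # relaunch the recorded kernels
        return static_logits       # output buffer is overwritten by replay

This is presumably why capture() receives self.current_ids and the cache positions explicitly: they act as the static buffers the graph keeps replaying over.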

@@ -1,14 +1,22 @@
 from typing import Any, List, Optional, Set
-from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig
+from transformers import (
+    LlamaTokenizer,
+    AutoTokenizer,
+    AutoConfig,
+    LlamaForCausalLM,
+    GenerationConfig,
+    StaticCache,
+    AutoModelForCausalLM,
+    BitsAndBytesConfig,
+)
 
 from ktransformers.server.schemas.base import ObjectID
 from ktransformers.server.utils.multi_timer import Profiler
 import torch
 import sys, os
-from ..base import ThreadContext,BackendInterfaceBase
+from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
-from ..args import ConfigArgs,default_args
+from ..args import ConfigArgs, default_args
 
 
 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py

@@ -28,21 +36,20 @@ class TextStreamer:
         self.token_cache = []
         self.print_len = 0
 
-    def put(self, value)->Optional[str]:
+    def put(self, value) -> Optional[str]:
         """
         Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
         """
-        if not isinstance(value,int):
+        if not isinstance(value, int):
             raise ValueError("TextStreamer only supports batch size 1, and int type input")
-
 
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
             return None
 
         # Add the new token to the cache and decodes the entire thing.
         self.token_cache.append(value)
-        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
+        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
 
         # After the symbol for a new line, we flush the cache.
         if text.endswith("\n"):
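As context for the put/end changes above, a hypothetical caller drives the streamer one token id at a time and prints whatever complete words come back (the tokenizer and the token id values here are assumed, for illustration only):

    streamer = TextStreamer(tokenizer)      # any HF tokenizer in scope
    for token_id in [1, 2, 3]:              # stand-in for generated ids
        chunk = streamer.put(token_id)      # None until a word boundary
        if chunk is not None:
            print(chunk, end="", flush=True)
    print(streamer.end() or "")             # flush whatever remains cached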

@@ -59,7 +66,7 @@ class TextStreamer:
         self.print_len += len(printable_text)
         return printable_text
 
-    def end(self)->Optional[str]:
+    def end(self) -> Optional[str]:
         """Flushes any remaining cache and prints a newline to stdout."""
         # Flush the cache, if it exists
         if len(self.token_cache) > 0:

@@ -71,7 +78,7 @@ class TextStreamer:
 
         self.next_tokens_are_prompt = True
         return printable_text
-
+
     def _is_chinese_char(self, cp):
         """Checks whether CP is the codepoint of a CJK character."""
         # This defines a "chinese character" as anything in the CJK Unicode block:

@@ -97,101 +104,91 @@ class TextStreamer:
         return False
 
 
 class TransformersThreadContext(ThreadContext):
     def get_local_messages(self):
         local_messages = []
         for m in self.messages:
-            local_messages.append(
-                {'role':m.role.value,
-                'content':m.get_text_content()}
-            )
-
+            local_messages.append({"role": m.role.value, "content": m.get_text_content()})
+
         return local_messages
 
 
 class TransformersInterface(BackendInterfaceBase):
-    use_static_cache : bool = True
-
+    use_static_cache: bool = True
+
     model: Any
     tokenizer: AutoTokenizer
 
     cache: StaticCache
-    generated_ids:torch.Tensor
-    seq_length:int
-
+    generated_ids: torch.Tensor
+    seq_length: int
+
     streamer: TextStreamer
 
     # thread_related
     last_request_id: Optional[str] = None
     ever_generated_ids: Set[int] = set()
 
-
-    def __init__(self, args:ConfigArgs = default_args):
+    def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
-        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True)
-        logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
-
-        self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
-        logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
-
-        self.streamer = TextStreamer(self.tokenizer)
-
+
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
+        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
+        logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
+
+        self.cache = StaticCache(
+            config=self.model.config,
+            max_batch_size=args.batch_size,
+            max_cache_len=args.cache_lens,
+            device=args.device,
+            dtype=self.model.dtype,
+        )
+        logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
+
+        self.streamer = TextStreamer(self.tokenizer)
 
     @property
     def current_ids(self):
-        return self.generated_ids[:,self.seq_length-1].unsqueeze(1)
-
+        return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)
+
     @property
     def active_cache_position(self):
-        return torch.tensor([self.seq_length-1], device=self.args.device)
+        return torch.tensor([self.seq_length - 1], device=self.args.device)
 
-    def tokenize_prompt(self,prompt:str):
-        input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
+    def tokenize_prompt(self, prompt: str):
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
         return input_ids
 
-    def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List):
+    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
         for m in messages:
-            if m['role']=='system':
-                logger.warn(f'change {m["role"]} to user')
-                m['role'] = 'user'
+            if m["role"] == "system":
+                logger.warning(f'change {m["role"]} to user')
+                m["role"] = "user"
 
         new_messages = [messages[0]]
         for m in messages[1:]:
-            if m['role'] == 'user' and new_messages[-1]['role']=='user':
-                logger.warn('merge two adjacent user messages')
-                new_messages[-1]['content']+=m['content']
+            if m["role"] == "user" and new_messages[-1]["role"] == "user":
+                logger.warning("merge two adjacent user messages")
+                new_messages[-1]["content"] += m["content"]
             else:
                 new_messages.append(m)
 
-        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
-
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            x = self.generated_ids[:,:self.seq_length]
-            y = input_ids[:,:self.seq_length]
-            # We can only hope that the input_ids are the same
-            unequal_mask = torch.ne(x,y)
-            unequal_positions = torch.nonzero(unequal_mask)
-            num_unequal_elements = unequal_mask.sum().item()
-            logger.warn(f'num_unequal_elements: {num_unequal_elements}')
-
-            input_ids = input_ids[:,self.seq_length:]
-        logger.debug(f'get input ids of shape {input_ids.shape}')
+            input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
+        else:
+            input_ids = self.tokenizer.apply_chat_template(
+                new_messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.args.device)
+        logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
 
-    def append_new_tokens(self,new_tokens:int)->Optional[str]:
-        self.generated_ids[0,self.seq_length] = new_tokens
-        self.seq_length+=1
-
+    def append_new_tokens(self, new_tokens: int) -> Optional[str]:
+        self.generated_ids[0, self.seq_length] = new_tokens
+        self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self,logits:torch.Tensor):
-        logits = logits/self.args.temperature
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = logits / self.args.temperature
 
         for token_idx in self.ever_generated_ids:
             if logits[token_idx] < 0:
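The format_and_tokenize_input_ids hunk above is the message-slice bug fix named in the commit title. Previously, a continuing thread re-tokenized the entire conversation and then sliced off the first self.seq_length tokens; whenever re-tokenization shifted token boundaries, the slice started mid-message and the code could only log a warning about the mismatch. The new code never slices: a continuing thread tokenizes just the newly appended message, while a fresh thread tokenizes the whole history with the generation prompt. A schematic comparison (variable names here are illustrative):

    # old: slice the re-tokenized full history - breaks if offsets shift
    full_ids = tokenizer.apply_chat_template(all_messages, return_tensors="pt")
    new_ids = full_ids[:, cached_len:]

    # new: tokenize only the message that was actually appended
    new_ids = tokenizer.apply_chat_template([all_messages[-1]], return_tensors="pt")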

@@ -200,7 +197,7 @@ class TransformersInterface(BackendInterfaceBase):
                 logits[token_idx] /= self.args.repetition_penalty
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
-
+
         sample = True
         if sample:
             last = torch.multinomial(probs, num_samples=1)
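The sampling step in this hunk penalizes every token id the interface has already produced, then samples from the softmax. A self-contained sketch of the same logic follows; the branch for negative logits is elided between hunks, so the multiplication shown is the standard repetition-penalty convention and an assumption here, as are the default temperature and penalty values:

    import torch

    def sample_token(logits: torch.Tensor, seen: set,
                     temperature: float = 0.8, penalty: float = 1.2) -> int:
        logits = logits / temperature
        for idx in seen:
            # Scale already-seen tokens toward lower probability: negative
            # logits are multiplied by the penalty, positive ones divided.
            if logits[idx] < 0:
                logits[idx] *= penalty
            else:
                logits[idx] /= penalty
        probs = torch.nn.functional.softmax(logits, dim=-1)
        token = int(torch.multinomial(probs, num_samples=1))
        seen.add(token)
        return token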

@@ -211,127 +208,124 @@ class TransformersInterface(BackendInterfaceBase):
         self.ever_generated_ids.add(last)
         return last
 
-
     def decode_one_tokens(self):
         if self.use_static_cache:
-            mask = torch.ones((1,self.seq_length)).to(self.args.device)
+            mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
-                use_cache=True
+                use_cache=True,
             )[0]
         else:
-            logits = self.model(
-                self.current_ids,
-                return_dict=False
-            )[0]
-        logits = logits[0,-1,:]
+            logits = self.model(self.current_ids, return_dict=False)[0]
+        logits = logits[0, -1, :]
 
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self,input_ids:torch.Tensor,is_new:bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter('prefill',input_ids_length)
-        logger.debug(f'input_ids: {input_ids.shape}')
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
             self.cache.reset()
             self.ever_generated_ids.clear()
             former_seq_length = 0
             self.seq_length = input_ids_length
             self.generated_ids = torch.zeros(
-                self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device
+                self.args.batch_size,
+                self.seq_length + self.args.max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
             )
         else:
-            logger.debug(f'generate_ids: {self.generated_ids.shape}')
+            logger.debug(f"generate_ids: {self.generated_ids.shape}")
             former_seq_length = self.seq_length
             self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens+1
+            expected_length = self.seq_length + self.args.max_new_tokens + 1
             delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length>0:
+            if delta_length > 0:
                 new_generate_ids = torch.zeros(
                     self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
                 )
-                self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1)
-        logger.debug(f'cache position: {former_seq_length} to {self.seq_length}')
-        cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device)
-        self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int)
+                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
+        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
+        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
 
-        mask = torch.ones((1,self.seq_length)).to(self.args.device)
+        mask = torch.ones((1, self.seq_length)).to(self.args.device)
         device = input_ids.device
-        if not(type(self) is TransformersInterface):
+        if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         if self.use_static_cache:
             logits = self.model(
-                inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                past_key_values=self.cache,
+                return_dict=False,
+                use_cache=True,
+                attention_mask=mask,
             )[0]
         else:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,return_dict=False
-            )[0]
+            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
-        next_token = self.logits_to_token(logits[0,-1,:])
+        next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
     @torch.no_grad
     def generate(self):
-        self.profiler.set_counter('decode',0)
+        self.profiler.set_counter("decode", 0)
         for _ in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                 next_token = self.decode_one_tokens()
-            self.profiler.inc('decode')
+            self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
                 assert self.args.batch_size == 1
                 break
             yield self.append_new_tokens(next_token)
         yield self.streamer.end()
 
-    def check_is_new(self,thread_id:str):
+    def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
             return True
         if self.last_request_id is None:
             self.last_request_id = thread_id
             return True
         else:
-            if self.last_request_id==thread_id:
+            if self.last_request_id == thread_id:
                 return False
             else:
                 self.last_request_id = thread_id
                 return True
 
-    async def inference(self,local_messages,thread_id:str):
-        self.profiler.create_and_start_timer('tokenize')
-        if isinstance(local_messages,List):
-            input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages)
-        elif isinstance(local_messages,str):
+    async def inference(self, local_messages, thread_id: str):
+        self.profiler.create_and_start_timer("tokenize")
+        if isinstance(local_messages, List):
+            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
+        elif isinstance(local_messages, str):
             input_ids = self.tokenize_prompt(local_messages)
         else:
-            raise ValueError('local_messages should be List or str')
+            raise ValueError("local_messages should be List or str")
 
-        self.profiler.pause_timer('tokenize')
+        self.profiler.pause_timer("tokenize")
 
-        self.profiler.create_and_start_timer('prefill')
-        for t in self.prefill(input_ids,self.check_is_new(thread_id)):
+        self.profiler.create_and_start_timer("prefill")
+        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             if t is not None:
-                print(t,end='')
+                print(t, end="")
             yield t
-        self.profiler.pause_timer('prefill')
+        self.profiler.pause_timer("prefill")
 
-        self.profiler.create_and_start_timer('decode')
+        self.profiler.create_and_start_timer("decode")
         for t in self.generate():
             if t is not None:
-                print(t,end='')
+                print(t, end="")
             yield t
-        print('')
-        self.profiler.pause_timer('decode')
+        print("")
+        self.profiler.pause_timer("decode")
         self.report_last_time_performance()
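One detail worth noting in prefill(): generated_ids is preallocated to prompt length + max_new_tokens + 1 slots, and when a continued thread outgrows it, the buffer is extended once by concatenating a zero block rather than reallocating per token. A minimal sketch of that growth rule (the function name is illustrative, not part of the codebase):

    import torch

    def ensure_capacity(generated_ids: torch.Tensor, expected_length: int) -> torch.Tensor:
        # Grow the preallocated token buffer only when the expected total
        # length (current tokens + max_new_tokens + 1) exceeds its capacity.
        delta = expected_length - generated_ids.shape[-1]
        if delta > 0:
            pad = torch.zeros(
                generated_ids.shape[0], delta,
                dtype=generated_ids.dtype, device=generated_ids.device,
            )
            generated_ids = torch.cat([generated_ids, pad], dim=-1)
        return generated_ids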