roll back ktransformers backend, add max_tokens and max_completion_tokens params

Author: qiyuxinlin
Date: 2025-04-21 12:55:37 +00:00
Parent: a1162eea01
Commit: 03a65d6bea
10 changed files with 144 additions and 161 deletions

View file

@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
query_updates[i].generated_token = generated_tokens[i].item()
if not query_manager.query_map[query_updates[i].id].is_prefill:
pos = query_updates[i].active_position
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
if pos < query_manager.query_map[query_updates[i].id].max_length:
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
def report_last_time_performance(profiler: Profiler):
try:
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
start_event.wait()
def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
"""Get sampling parameters and handle default values and edge cases"""
if max_tokens is not None:
max_completion_tokens = max_tokens
if max_completion_tokens is None:
max_completion_tokens = self.args.max_new_tokens
else:
max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
if temperature is None:
temperature = Config().temperature
temperature = self.args.temperature
if top_p is None:
top_p = Config().top_p
top_p = self.args.top_p
if temperature == 0:
temperature = 0.0001
if top_p == 0:
top_p = 0.0001
return temperature, top_p
return temperature, top_p, max_completion_tokens
def run_queue_proxy(self):
loop = asyncio.new_event_loop()
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
profiler = Profiler()
profiler.create_and_start_timer("tokenize")
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
query_add.stop_criteria = stop_criteria
temperature, top_p = self.get_sampling_params(temperature, top_p)
temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
query_add.sample_options.temperature = temperature
query_add.sample_options.top_p = top_p
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)
if query_add.estimated_length < query_add.query_length:
raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
query_id = self.sched_client.add_query(query_add)
queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
queue = asyncio.Queue(maxsize=max_new_tokens)
self.queue_map[query_id] = queue
self.thread_map[thread_id] = query_id
is_first_token = True
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
profiler.pause_timer("decode")
report_last_time_performance(profiler)
yield self.streamer.end(), None
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
if profiler.get_counter('decode') >= max_new_tokens - 1:
yield "", "length"
else:
yield "", "stop"

View file

@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
if max_tokens is not None:
max_completion_tokens = max_tokens
if max_completion_tokens is None:
max_new_tokens = self.args.max_new_tokens
else:
max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
input_ids.shape[-1] + max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
@@ -222,16 +228,17 @@ class KTransformersInterface(TransformersInterface):
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
yield self.append_new_tokens(next_token)
@property
def active_cache_position(self):
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
return torch.tensor([self.seq_length - 1], device=device)
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
async with self._infer_lock:
async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
yield v
# return this inference raw usage
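
The new per-request budget in prefill() comes down to one assignment: cap the requested max_new_tokens by the KV-cache room left after the prompt, minus one slot for the token already produced during prefill. A small illustrative sketch (decode_budget and the example numbers are not part of the commit):

def decode_budget(requested_max_new_tokens: int, cache_lens: int, seq_length: int) -> int:
    """Effective number of decode steps after prefill, mirroring the assignment in the hunk above."""
    # cache_lens - seq_length is the KV-cache room left once the prompt is in place;
    # the trailing -1 accounts for the token already emitted at the end of prefill.
    return min(requested_max_new_tokens, cache_lens - seq_length) - 1

# e.g. a 300-token prompt in a 4096-slot cache with max_completion_tokens=256
print(decode_budget(256, cache_lens=4096, seq_length=300))   # 255
# a long prompt leaves less room than requested
print(decode_budget(256, cache_lens=4096, seq_length=3900))  # 195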

View file

@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
logger.debug(f"input_ids: {input_ids.shape}")
if max_tokens is not None:
max_completion_tokens = max_tokens
if max_completion_tokens is None:
max_new_tokens = self.args.max_new_tokens
else:
max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
if is_new:
self.ever_generated_ids.clear()
same_prefix = 0
@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
input_ids.shape[-1] + max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
expected_length = self.seq_length + max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
yield self.append_new_tokens(next_token)
@torch.no_grad
def generate(self):
self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
if(self.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end(), "length"
return
logger.info(f"max_new_tokens: {self.max_new_tokens}")
self.profiler.set_counter("decode", 0)
for i in range(1, self.max_new_tokens):
@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
self.last_request_id = thread_id
return True
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
# Check if tools are present
has_tools = tools is not None and len(tools) > 0
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
# Handle tool calling
if has_tools:
# Start collecting tokens until we detect a tool call
collected_tokens = ""
is_collecting_tool_call = False
is_function_name_collected = False
function_name = ""
collected_arguments = ""
brackets_count = 0
for t, finish_reason in self.generate():
if t is not None:
print(t, end="", flush=True)
collected_tokens += t
# Check if we're starting a tool call
if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
is_collecting_tool_call = True
# Generate a unique tool call ID
tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
# Send first tool call info
if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
# If tools are provided, use the first one's name
recommended_function = tools[0].function.name
else:
# Otherwise try to extract from context
function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
recommended_function = function_match.group(1) if function_match else ""
yield {
'tool_call': {
'id': tool_call_id,
'type': 'function',
'index': 0,
'function': {
'name': recommended_function,
'arguments': ""
}
},
'first_chunk': True
}
# Extract function name if we're collecting tool call
if is_collecting_tool_call and not is_function_name_collected:
name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
if name_match:
function_name = name_match.group(1)
is_function_name_collected = True
# Track argument collection
if is_collecting_tool_call and is_function_name_collected:
args_position = collected_tokens.find('"arguments"')
if args_position > -1:
# Find the start of the JSON object after "arguments":
json_start = collected_tokens.find('{', args_position)
if json_start > -1:
for i in range(json_start, len(collected_tokens)):
char = collected_tokens[i]
collected_arguments += char
if char == '{':
brackets_count += 1
elif char == '}':
brackets_count -= 1
# Check if we've completed the arguments JSON
if brackets_count == 0:
# Send argument chunk
yield {
'tool_call': {
'id': tool_call_id,
'type': 'function',
'function': {
'name': function_name,
'arguments': collected_arguments
}
},
'argument_chunk': collected_arguments,
'last_chunk': True,
'prompt_tokens': 176,
'completion_tokens': 20
}
# Reset for next potential tool call
collected_tokens = ""
is_collecting_tool_call = False
is_function_name_collected = False
function_name = ""
collected_arguments = ""
brackets_count = 0
break
# Handle finish reason
if finish_reason is not None:
yield "", finish_reason
print("")
else:
# Regular text generation (no tools)
for t, finish_reason in self.generate():
if t is not None:
print(t, end="",flush=True)
yield t, finish_reason
print("")
for t, finish_reason in self.generate():
if t is not None:
print(t, end="",flush=True)
yield t, finish_reason
print("")
self.profiler.pause_timer("decode")
self.report_last_time_performance()
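
Both prefill() implementations also grow the generated_ids buffer against the per-request budget instead of args.max_new_tokens. Below is a hedged sketch of that growth step, assuming the same expected_length/delta_length arithmetic as the hunks above; grow_generated_ids and the example shapes are illustrative, not the exact code.

import torch

def grow_generated_ids(generated_ids: torch.Tensor, seq_length: int, max_new_tokens: int) -> torch.Tensor:
    """Ensure the token buffer can hold the prompt plus the per-request decode budget."""
    # one extra slot for the token produced at the end of prefill
    expected_length = seq_length + max_new_tokens + 1
    delta_length = expected_length - generated_ids.shape[-1]
    if delta_length > 0:
        pad = torch.zeros(generated_ids.shape[0], delta_length,
                          dtype=generated_ids.dtype, device=generated_ids.device)
        generated_ids = torch.cat([generated_ids, pad], dim=-1)
    return generated_ids

# Example: a (1, 512) buffer asked to hold a 400-token prompt plus 256 new tokens
buf = torch.zeros(1, 512, dtype=torch.int)
print(grow_generated_ids(buf, seq_length=400, max_new_tokens=256).shape)  # torch.Size([1, 657])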