Mirror of https://github.com/kvcache-ai/ktransformers.git
roll back ktransformers backend, add max_tokens, max_completion_tokens param
This commit is contained in:
parent a1162eea01
commit 03a65d6bea
10 changed files with 144 additions and 161 deletions
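For context, the request-level effect of this change can be sketched as follows; the endpoint URL, model name, and token values are placeholders taken loosely from the test scripts in this commit, not a prescribed configuration.

import requests

# Hypothetical smoke test: max_tokens, when present, takes precedence over
# max_completion_tokens, and both are clamped server-side by --max_new_tokens.
payload = {
    "model": "DeepSeek-V3",
    "messages": [{"role": "user", "content": "Please tell me a joke"}],
    "stream": False,
    "max_tokens": 64,               # optional; overrides max_completion_tokens when set
    "max_completion_tokens": 128,   # optional; defaults to 50 in the new schemas
}
resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload)
# response shape follows the OpenAI chat-completions format
print(resp.json()["choices"][0]["message"]["content"])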
@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
 "<tools▁end>":"<|tool▁calls▁end|>"
 }
 # Use check_client_connected for early stopping
-async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
 if isinstance(res, RawUsage):
 # Final return on utilization
 raw_usage = res

@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
 "<tool▁end>":"<|tool▁call▁end|>",
 "<tools▁end>":"<|tool▁calls▁end|>"
 }
-async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
 if isinstance(res, RawUsage):
 raw_usage = res
 usage = CompletionUsage(

@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()

 @router.post("/completions",tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request:Request, create:CompletionCreate):
 id = str(uuid4())

 interface = get_interface()
@@ -20,7 +20,7 @@ async def create_completion(request:Request, create:CompletionCreate):

 if create.stream:
 async def inner():
-async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
 if isinstance(res, RawUsage):
 raw_usage = res
 else:
@@ -32,7 +32,7 @@ async def create_completion(request:Request, create:CompletionCreate):
 return stream_response(request,inner())
 else:
 comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):
 if isinstance(res, RawUsage):
 raw_usage = res
 else:

@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
 query_updates[i].generated_token = generated_tokens[i].item()
 if not query_manager.query_map[query_updates[i].id].is_prefill:
 pos = query_updates[i].active_position
-query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+if pos < query_manager.query_map[query_updates[i].id].max_length:
+query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
 try:
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):

 start_event.wait()

-def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
 """Get sampling parameters and handle default values and edge cases"""
+if max_tokens is not None:
+max_completion_tokens = max_tokens
+if max_completion_tokens is None:
+max_completion_tokens = self.args.max_new_tokens
+else:
+max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
 if temperature is None:
-temperature = Config().temperature
+temperature = self.args.temperature
 if top_p is None:
-top_p = Config().top_p
+top_p = self.args.top_p

 if temperature == 0:
 temperature = 0.0001
 if top_p == 0:
 top_p = 0.0001

-return temperature, top_p
+return temperature, top_p, max_completion_tokens

 def run_queue_proxy(self):
 loop = asyncio.new_event_loop()
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
 logger.debug(f"get input ids of shape {input_ids.shape}")
 return input_ids

-async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
 profiler = Profiler()
 profiler.create_and_start_timer("tokenize")

@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
 stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
 query_add.stop_criteria = stop_criteria

-temperature, top_p = self.get_sampling_params(temperature, top_p)
+temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)

 query_add.sample_options.temperature = temperature
 query_add.sample_options.top_p = top_p
-query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)

 if query_add.estimated_length < query_add.query_length:
 raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')

 query_id = self.sched_client.add_query(query_add)
-queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+queue = asyncio.Queue(maxsize=max_new_tokens)
 self.queue_map[query_id] = queue
 self.thread_map[thread_id] = query_id
 is_first_token = True
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
 profiler.pause_timer("decode")
 report_last_time_performance(profiler)
 yield self.streamer.end(), None
-if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+if profiler.get_counter('decode') >= max_new_tokens - 1:
 yield "", "length"
 else:
 yield "", "stop"

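Read on its own, the new `get_params` resolution amounts to: `max_tokens` overrides `max_completion_tokens`, and the result is capped by the server's `max_new_tokens`. A standalone sketch of that logic (function and argument names here are illustrative, not part of the codebase):

from typing import Optional

def resolve_max_new_tokens(server_max_new_tokens: int,
                           max_tokens: Optional[int] = None,
                           max_completion_tokens: Optional[int] = None) -> int:
    # max_tokens, when provided, overrides max_completion_tokens (mirrors the hunk above)
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    # fall back to the server default, otherwise clamp to it
    if max_completion_tokens is None:
        return server_max_new_tokens
    return min(server_max_new_tokens, max_completion_tokens)

assert resolve_max_new_tokens(2048) == 2048                                    # nothing requested
assert resolve_max_new_tokens(2048, max_completion_tokens=300) == 300
assert resolve_max_new_tokens(2048, max_tokens=64, max_completion_tokens=300) == 64
assert resolve_max_new_tokens(2048, max_tokens=999999) == 2048                 # capped by the server limit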
@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):


 @torch.no_grad
-def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
 input_ids_length = input_ids.shape[-1]
+if max_tokens is not None:
+max_completion_tokens = max_tokens
+if max_completion_tokens is None:
+max_new_tokens = self.args.max_new_tokens
+else:
+max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
 if(input_ids_length >= self.args.cache_lens):
 logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
 self.seq_length = input_ids_length
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
 if getattr(self, 'generated_ids', None) is None:
 self.generated_ids = torch.zeros(
 self.args.batch_size,
-input_ids.shape[-1] + self.args.max_new_tokens + 1,
+input_ids.shape[-1] + max_new_tokens + 1,
 dtype=torch.int,
 device=self.args.device,
 )
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):

 former_seq_length = self.seq_length
 self.seq_length += input_ids_length
-expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
 delta_length = expected_length - self.generated_ids.shape[-1]
 if delta_length > 0:
 new_generate_ids = torch.zeros(
@@ -222,6 +228,7 @@ class KTransformersInterface(TransformersInterface):
 MLAWrapperSingleton.reset_buffer()
 self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
 next_token = self.logits_to_token(logits[0, -1, :])
+self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
 yield self.append_new_tokens(next_token)

 @property
@@ -229,9 +236,9 @@ class KTransformersInterface(TransformersInterface):
 device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
 return torch.tensor([self.seq_length - 1], device=device)

-async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
 async with self._infer_lock:
-async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
 yield v

 # return this inference raw usage

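To see what the buffer-sizing changes above mean per request, here is a toy calculation; all numbers are invented, and `cache_lens`/`server_max_new_tokens` merely stand in for the corresponding server arguments.

# Hypothetical sizes, mirroring the formulas in the hunks above.
cache_lens = 8192              # total KV-cache length the server was started with
server_max_new_tokens = 2048   # --max_new_tokens
prompt_len = 100
requested = 300                # effective max_completion_tokens after resolution

max_new_tokens = min(server_max_new_tokens, requested)                 # 300
generated_ids_len = prompt_len + max_new_tokens + 1                    # 401 slots instead of 2149
per_request_budget = min(max_new_tokens, cache_lens - prompt_len) - 1  # 299 decode steps

print(max_new_tokens, generated_ids_len, per_request_budget)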
@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
 return self.logits_to_token(logits)

 @torch.no_grad
-def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
 input_ids_length = input_ids.shape[-1]
 logger.debug(f"input_ids: {input_ids.shape}")
+if max_tokens is not None:
+max_completion_tokens = max_tokens
+if max_completion_tokens is None:
+max_new_tokens = self.args.max_new_tokens
+else:
+max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
 if is_new:
 self.ever_generated_ids.clear()
 same_prefix = 0
@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
 if getattr(self, 'generated_ids', None) is None:
 self.generated_ids = torch.zeros(
 self.args.batch_size,
-input_ids.shape[-1] + self.args.max_new_tokens + 1,
+input_ids.shape[-1] + max_new_tokens + 1,
 dtype=torch.int,
 device=self.args.device,
 )
@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
 logger.debug(f"generate_ids: {self.generated_ids.shape}")
 former_seq_length = self.seq_length
 self.seq_length += input_ids_length
-expected_length = self.seq_length + self.args.max_new_tokens + 1
+expected_length = self.seq_length + max_new_tokens + 1
 delta_length = expected_length - self.generated_ids.shape[-1]
 if delta_length > 0:
 new_generate_ids = torch.zeros(
@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):

 self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
 next_token = self.logits_to_token(logits[0, -1, :])
+self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
 yield self.append_new_tokens(next_token)

 @torch.no_grad
 def generate(self):
-self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
 logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
 if(self.max_new_tokens <= 0):
 logger.warning("max_new_tokens is less than 0")
 yield self.streamer.end(), "length"
 return
-logger.info(f"max_new_tokens: {self.max_new_tokens}")
 self.profiler.set_counter("decode", 0)

 for i in range(1, self.max_new_tokens):
@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
 self.last_request_id = thread_id
 return True

-async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
 self.streamer.reset()
 self.profiler.create_and_start_timer("tokenize")

-# Check if tools are present
-has_tools = tools is not None and len(tools) > 0

 if isinstance(local_messages, List):
 input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
 elif isinstance(local_messages, str):
+#local_messages = local_messages[0]['content']
 input_ids = self.tokenize_prompt(local_messages)
+#input_ids = torch.tensor([[6366]], device=input_ids.device)
 else:
 raise ValueError("local_messages should be List or str")

@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
 )

 self.profiler.pause_timer("tokenize")

 self.profiler.create_and_start_timer("prefill")

 if Config().user_force_think:
@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
 print(think, end="",flush=True)
 yield think, None

-for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
+# output think token after prefill done
 if t is not None:
 print(t, end="",flush=True)
 yield t, None
 self.profiler.pause_timer("prefill")

 self.profiler.create_and_start_timer("decode")
-# Handle tool calling
-if has_tools:
-# Start collecting tokens until we detect a tool call
-collected_tokens = ""
-is_collecting_tool_call = False
-is_function_name_collected = False
-function_name = ""
-collected_arguments = ""
-brackets_count = 0
-
-for t, finish_reason in self.generate():
-if t is not None:
-print(t, end="", flush=True)
-collected_tokens += t
-
-# Check if we're starting a tool call
-if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-is_collecting_tool_call = True
-
-# Generate a unique tool call ID
-tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-
-# Send first tool call info
-if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-# If tools are provided, use the first one's name
-recommended_function = tools[0].function.name
-else:
-# Otherwise try to extract from context
-function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-recommended_function = function_match.group(1) if function_match else ""
-
-yield {
-'tool_call': {
-'id': tool_call_id,
-'type': 'function',
-'index': 0,
-'function': {
-'name': recommended_function,
-'arguments': ""
-}
-},
-'first_chunk': True
-}
-
-# Extract function name if we're collecting tool call
-if is_collecting_tool_call and not is_function_name_collected:
-name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-if name_match:
-function_name = name_match.group(1)
-is_function_name_collected = True
-
-# Track argument collection
-if is_collecting_tool_call and is_function_name_collected:
-args_position = collected_tokens.find('"arguments"')
-if args_position > -1:
-# Find the start of the JSON object after "arguments":
-json_start = collected_tokens.find('{', args_position)
-if json_start > -1:
-for i in range(json_start, len(collected_tokens)):
-char = collected_tokens[i]
-collected_arguments += char
-
-if char == '{':
-brackets_count += 1
-elif char == '}':
-brackets_count -= 1
-
-# Check if we've completed the arguments JSON
-if brackets_count == 0:
-# Send argument chunk
-yield {
-'tool_call': {
-'id': tool_call_id,
-'type': 'function',
-'function': {
-'name': function_name,
-'arguments': collected_arguments
-}
-},
-'argument_chunk': collected_arguments,
-'last_chunk': True,
-'prompt_tokens': 176,
-'completion_tokens': 20
-}
-# Reset for next potential tool call
-collected_tokens = ""
-is_collecting_tool_call = False
-is_function_name_collected = False
-function_name = ""
-collected_arguments = ""
-brackets_count = 0
-break
-
-# Handle finish reason
-if finish_reason is not None:
-yield "", finish_reason
-
-print("")
-else:
-# Regular text generation (no tools)
-for t, finish_reason in self.generate():
-if t is not None:
-print(t, end="",flush=True)
-yield t, finish_reason
-print("")
+for t, finish_reason in self.generate():
+if t is not None:
+print(t, end="",flush=True)
+yield t, finish_reason
+print("")

 self.profiler.pause_timer("decode")
 self.report_last_time_performance()

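The removed branch above located the end of a streamed tool-call `arguments` object by counting braces in the collected text. A minimal standalone sketch of that brace-counting idea follows (names are illustrative; unlike a full parser it ignores braces inside string literals):

from typing import Optional

def extract_json_object(text: str, start: int) -> Optional[str]:
    """Return the balanced {...} substring starting at `start`, or None if it is not closed yet."""
    depth = 0
    for i in range(start, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None  # still streaming, object not complete

chunk = '{"name": "get_weather", "arguments": {"location": "Hangzhou"}}'
args_start = chunk.find('{', chunk.find('"arguments"'))
print(extract_json_object(chunk, args_start))  # {"location": "Hangzhou"}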
@@ -1,7 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum

 from pydantic import BaseModel, Field

 from ktransformers.server.schemas.base import Object
@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice

 from uuid import uuid4

-from pydantic import BaseModel, Field

 class Role(Enum):
 system = 'system'
@@ -67,6 +65,8 @@ class ChatCompletionCreate(BaseModel):
 stream_options: Optional[Dict[str, Any]] = None
 frequency_penalty: float = 0
 presence_penalty: float = 0
+max_tokens: Optional[int] = Field(default=50)
+max_completion_tokens: Optional[int] = Field(default=50)

 def get_tokenizer_messages(self):
 return [m.to_tokenizer_message() for m in self.messages]

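One practical consequence of these defaults: a client that omits both fields is now capped at 50 new tokens. A toy pydantic model with the same (assumed) field names shows the behaviour:

from typing import Optional
from pydantic import BaseModel, Field

class ChatCompletionCreateSketch(BaseModel):
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)

print(ChatCompletionCreateSketch())                # max_tokens=50 max_completion_tokens=50
print(ChatCompletionCreateSketch(max_tokens=256))  # an explicit limit replaces the default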
@@ -1,7 +1,6 @@
 from typing import List, Optional
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from ..base import Object

@@ -9,8 +8,10 @@ class CompletionCreate(BaseModel):
 model: str
 prompt: str | List[str]
 stream: bool = False
-temperature: Optional[float] = None
-top_p: Optional[float] = None
+temperature: Optional[float] = Field(default=0.6)
+top_p: Optional[float] = Field(default=1)
+max_tokens: Optional[int] = Field(default=50)
+max_completion_tokens: Optional[int] = Field(default=50)

 def get_tokenizer_messages(self):
 if isinstance(self.prompt,List):

ktransformers/tests/function_call_test.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from openai import OpenAI
+
+def send_messages(messages):
+    response = client.chat.completions.create(
+        model="deepseek-chat",
+        messages=messages,
+        tools=tools
+    )
+    return response.choices[0].message
+
+client = OpenAI(
+    api_key="placeholder",
+    base_url="http://0.0.0.0:10002/v1",
+)
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather of an location, the user shoud supply a location first",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    }
+                },
+                "required": ["location"]
+            },
+        }
+    },
+]
+
+messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
+message = send_messages(messages)
+print(f"User>\t {messages[0]['content']}")
+print(message)
+tool = message.tool_calls[0]
+messages.append(message)
+
+messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
+message = send_messages(messages)
+print(f"Model>\t {message.content}")

@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
 try:
-payload = {
-"messages": [
-{"role": "system", "content": ""},
-{"role": "user", "content": prompt_list[request_id]}
-],
-"model": "DeepSeek-V3",
-"temperature": 0.3,
-"top_p": 1.0,
-"stream": True # enable streaming output
-}
-
 headers = {
 'accept': 'application/json',
@@ -103,7 +94,35 @@ async def fetch_event_stream(session, payload, request_id):

 async def main(prompt_id):
 async with aiohttp.ClientSession() as session:
-tasks = [fetch_event_stream(session, prompt_id)]
+payload = {
+"messages": [
+{"role": "system", "content": ""},
+{"role": "user", "content": prompt_list[prompt_id]}
+],
+"model": "DeepSeek-V3",
+"stream": True,
+"max_completion_tokens": 2,
+# "temperature": 0.3,
+# "top_p": 1.0,
+# "max_tokens" : 20,
+}
+tasks = [fetch_event_stream(session, payload, prompt_id)]
+await asyncio.gather(*tasks)
+
+payload["temperature"] = 0.3
+tasks = [fetch_event_stream(session, payload, prompt_id)]
+await asyncio.gather(*tasks)
+
+payload["top_p"] = 1
+tasks = [fetch_event_stream(session, payload, prompt_id)]
+await asyncio.gather(*tasks)
+
+payload["max_tokens"] = 200
+tasks = [fetch_event_stream(session, payload, prompt_id)]
+await asyncio.gather(*tasks)
+
+payload["stream"] = False
+tasks = [fetch_event_stream(session, payload, prompt_id)]
 await asyncio.gather(*tasks)

 if __name__ == "__main__":

third_party/llamafile/iqk_mul_mat.inc (vendored, 2 lines changed)
@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {

 default:
 {
-printf("case:%d",typeA);
+// printf("case:%d",typeA);
 return false;
 }
