diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py
index e1ff30e..a5eb986 100644
--- a/ktransformers/server/api/openai/endpoints/chat.py
+++ b/ktransformers/server/api/openai/endpoints/chat.py
@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 "":"<|tool▁calls▁end|>"
             }
             # Use check_client_connected for early stopping
-            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     # Final return on utilization
                     raw_usage = res
@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "":"<|tool▁call▁end|>",
             "":"<|tool▁calls▁end|>"
         }
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
            if isinstance(res, RawUsage):
                raw_usage = res
                usage = CompletionUsage(
diff --git a/ktransformers/server/api/openai/legacy/completions.py b/ktransformers/server/api/openai/legacy/completions.py
index 7ce2d2a..e46729f 100644
--- a/ktransformers/server/api/openai/legacy/completions.py
+++ b/ktransformers/server/api/openai/legacy/completions.py
@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()
 
 @router.post("/completions",tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request:Request, create:CompletionCreate):
     id = str(uuid4())
 
     interface = get_interface()
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
 
     if create.stream:
         async def inner():
-            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     raw_usage = res
                 else:
@@ -32,7 +32,7 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):
            if isinstance(res, RawUsage):
                raw_usage = res
            else:
diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index 582fabb..74c680d 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
 
 def report_last_time_performance(profiler: Profiler):
     try:
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
         start_event.wait()
 
-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens
 
     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")
 
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
 
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                     profiler.pause_timer("decode")
                     report_last_time_performance(profiler)
                     yield self.streamer.end(), None
-                    if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                    if profiler.get_counter('decode') >= max_new_tokens - 1:
                         yield "", "length"
                     else:
                         yield "", "stop"
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 690d09b..fd2a808 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
 
             former_seq_length = self.seq_length
             self.seq_length += input_ids_length
-            expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+            expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
             delta_length = expected_length - self.generated_ids.shape[-1]
             if delta_length > 0:
                 new_generate_ids = torch.zeros(
@@ -222,16 +228,17 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)
-
+
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
 
         # return this inference raw usage
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index df6a171..6bde540 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
-
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if is_new:
             self.ever_generated_ids.clear()
             same_prefix = 0
@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
             if getattr(self, 'generated_ids', None) is None:
                 self.generated_ids = torch.zeros(
                     self.args.batch_size,
-                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    input_ids.shape[-1] + max_new_tokens + 1,
                     dtype=torch.int,
                     device=self.args.device,
                 )
@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = self.seq_length + max_new_tokens + 1
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)
 
     @torch.no_grad
     def generate(self):
-        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end(), "length"
             return
-        logger.info(f"max_new_tokens: {self.max_new_tokens}")
 
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.max_new_tokens):
@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
-
-        # Check if tools are present
-        has_tools = tools is not None and len(tools) > 0
-
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
+            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
+            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
 
@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
             )
 
         self.profiler.pause_timer("tokenize")
+
         self.profiler.create_and_start_timer("prefill")
 
         if Config().user_force_think:
@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think, None
 
-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
+            # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")
         self.profiler.create_and_start_timer("decode")
-
-        # Handle tool calling
-        if has_tools:
-            # Start collecting tokens until we detect a tool call
-            collected_tokens = ""
-            is_collecting_tool_call = False
-            is_function_name_collected = False
-            function_name = ""
-            collected_arguments = ""
-            brackets_count = 0
-
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    collected_tokens += t
-
-                    # Check if we're starting a tool call
-                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-                        is_collecting_tool_call = True
-
-                        # Generate a unique tool call ID
-                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-
-                        # Send first tool call info
-                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-                            # If tools are provided, use the first one's name
-                            recommended_function = tools[0].function.name
-                        else:
-                            # Otherwise try to extract from context
-                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                            recommended_function = function_match.group(1) if function_match else ""
-
-                        yield {
-                            'tool_call': {
-                                'id': tool_call_id,
-                                'type': 'function',
-                                'index': 0,
-                                'function': {
-                                    'name': recommended_function,
-                                    'arguments': ""
-                                }
-                            },
-                            'first_chunk': True
-                        }
-
-                    # Extract function name if we're collecting tool call
-                    if is_collecting_tool_call and not is_function_name_collected:
-                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                        if name_match:
-                            function_name = name_match.group(1)
-                            is_function_name_collected = True
-
-                    # Track argument collection
-                    if is_collecting_tool_call and is_function_name_collected:
-                        args_position = collected_tokens.find('"arguments"')
-                        if args_position > -1:
-                            # Find the start of the JSON object after "arguments":
-                            json_start = collected_tokens.find('{', args_position)
-                            if json_start > -1:
-                                for i in range(json_start, len(collected_tokens)):
-                                    char = collected_tokens[i]
-                                    collected_arguments += char
-
-                                    if char == '{':
-                                        brackets_count += 1
-                                    elif char == '}':
-                                        brackets_count -= 1
-
-                                    # Check if we've completed the arguments JSON
-                                    if brackets_count == 0:
-                                        # Send argument chunk
-                                        yield {
-                                            'tool_call': {
-                                                'id': tool_call_id,
-                                                'type': 'function',
-                                                'function': {
-                                                    'name': function_name,
-                                                    'arguments': collected_arguments
-                                                }
-                                            },
-                                            'argument_chunk': collected_arguments,
-                                            'last_chunk': True,
-                                            'prompt_tokens': 176,
-                                            'completion_tokens': 20
-                                        }
-                                        # Reset for next potential tool call
-                                        collected_tokens = ""
-                                        is_collecting_tool_call = False
-                                        is_function_name_collected = False
-                                        function_name = ""
-                                        collected_arguments = ""
-                                        brackets_count = 0
-                                        break
-
-            # Handle finish reason
-            if finish_reason is not None:
-                yield "", finish_reason
-
-            print("")
-        else:
-            # Regular text generation (no tools)
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="",flush=True)
-                yield t, finish_reason
-            print("")
-
+        for t, finish_reason in self.generate():
+            if t is not None:
+                print(t, end="",flush=True)
+            yield t, finish_reason
+        print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py
index a471b28..643c81c 100644
--- a/ktransformers/server/schemas/endpoints/chat.py
+++ b/ktransformers/server/schemas/endpoints/chat.py
@@ -1,7 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
-
 from pydantic import BaseModel, Field
 
 from ktransformers.server.schemas.base import Object
@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice
 
 from uuid import uuid4
 
-from pydantic import BaseModel, Field
 
 class Role(Enum):
     system = 'system'
@@ -67,7 +65,9 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)
+
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
diff --git a/ktransformers/server/schemas/legacy/completions.py b/ktransformers/server/schemas/legacy/completions.py
index ea936ea..2d83212 100644
--- a/ktransformers/server/schemas/legacy/completions.py
+++ b/ktransformers/server/schemas/legacy/completions.py
@@ -1,7 +1,6 @@
 from typing import List, Optional
 from enum import Enum
-
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from ..base import Object
 
@@ -9,9 +8,11 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1)
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)
+
     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
             self.get_tokenizer_messages('\n'.join(self.prompt))
diff --git a/ktransformers/tests/function_call_test.py b/ktransformers/tests/function_call_test.py
new file mode 100644
index 0000000..a5d6569
--- /dev/null
+++ b/ktransformers/tests/function_call_test.py
@@ -0,0 +1,45 @@
+from openai import OpenAI
+
+def send_messages(messages):
+    response = client.chat.completions.create(
+        model="deepseek-chat",
+        messages=messages,
+        tools=tools
+    )
+    return response.choices[0].message
+
+client = OpenAI(
+    api_key="placeholder",
+    base_url="http://0.0.0.0:10002/v1",
+)
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather of a location, the user should supply a location first",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    }
+                },
+                "required": ["location"]
+            },
+        }
+    },
+]
+
+messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
+message = send_messages(messages)
+print(f"User>\t {messages[0]['content']}")
+print(message)
+tool = message.tool_calls[0]
+messages.append(message)
+
+messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
+message = send_messages(messages)
+print(f"Model>\t {message.content}")
\ No newline at end of file
diff --git a/ktransformers/tests/test_client.py b/ktransformers/tests/test_client.py
index d4619a9..1f6b684 100644
--- a/ktransformers/tests/test_client.py
+++ b/ktransformers/tests/test_client.py
@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
     try:
-        payload = {
-            "messages": [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": prompt_list[request_id]}
-            ],
-            "model": "DeepSeek-V3",
-            "temperature": 0.3,
-            "top_p": 1.0,
-            "stream": True  # 开启流式输出
-        }
+
 
         headers = {
             'accept': 'application/json',
@@ -103,7 +94,35 @@ async def fetch_event_stream(session, request_id):
 
 async def main(prompt_id):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, prompt_id)]
+        payload = {
+            "messages": [
+                {"role": "system", "content": ""},
+                {"role": "user", "content": prompt_list[prompt_id]}
+            ],
+            "model": "DeepSeek-V3",
+            "stream": True,
+            "max_completion_tokens": 2,
+            # "temperature": 0.3,
+            # "top_p": 1.0,
+            # "max_tokens" : 20,
+        }
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["temperature"] = 0.3
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["top_p"] = 1
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["max_tokens"] = 200
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["stream"] = False
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)
 
 if __name__ == "__main__":
diff --git a/third_party/llamafile/iqk_mul_mat.inc b/third_party/llamafile/iqk_mul_mat.inc
index 694467f..3dee90b 100644
--- a/third_party/llamafile/iqk_mul_mat.inc
+++ b/third_party/llamafile/iqk_mul_mat.inc
@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
 
         default:
             {
-                printf("case:%d",typeA);
+                // printf("case:%d",typeA);
                 return false;
             }
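Reviewer note: the token-budget rule this patch adds in `get_params` (balance_serve.py) and mirrors in both `prefill` implementations can be summarized as a small standalone function. The sketch below is an illustration only, not code from the patch; `server_max_new_tokens` stands in for `self.args.max_new_tokens`.

```python
from typing import Optional

def resolve_max_new_tokens(server_max_new_tokens: int,
                           max_tokens: Optional[int] = None,
                           max_completion_tokens: Optional[int] = None) -> int:
    """Clamping rule from the patch: max_tokens (the legacy OpenAI field)
    overrides max_completion_tokens, and the result is always capped by
    the server-side max_new_tokens limit."""
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    if max_completion_tokens is None:
        return server_max_new_tokens
    return min(server_max_new_tokens, max_completion_tokens)

assert resolve_max_new_tokens(1000) == 1000                     # neither field sent
assert resolve_max_new_tokens(1000, max_tokens=200) == 200      # max_tokens wins
assert resolve_max_new_tokens(1000, max_completion_tokens=50) == 50
assert resolve_max_new_tokens(100, max_tokens=500) == 100       # server limit caps the request
```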
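A minimal client-side check of the new request fields, modeled on the updated test_client.py; the host/port and model name are taken from the tests in this patch and may need adjusting for a real deployment.

```python
from openai import OpenAI

# Assumes a ktransformers server running locally, as in the tests above.
client = OpenAI(api_key="placeholder", base_url="http://localhost:10002/v1")

# max_tokens should cap the completion at roughly 20 generated tokens;
# max_completion_tokens behaves the same way and is overridden by max_tokens.
resp = client.chat.completions.create(
    model="DeepSeek-V3",
    messages=[{"role": "user", "content": "Please introduce Harry Potter."}],
    max_tokens=20,
)
print(resp.choices[0].message.content)
print(resp.usage)
```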