roll back ktransformers backend, add max_tokens, max_completion_tokens param

qiyuxinlin 2025-04-21 12:55:37 +00:00
parent a1162eea01
commit 03a65d6bea
10 changed files with 144 additions and 161 deletions
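For context, the sketch below shows how a client could exercise the new parameters once this commit is applied, using a recent openai Python SDK. The endpoint and port follow the test scripts later in this diff; the model name and token limits are illustrative, not taken from any shipped config.

from openai import OpenAI

client = OpenAI(api_key="placeholder", base_url="http://localhost:10002/v1")

response = client.chat.completions.create(
    model="DeepSeek-V3",
    messages=[{"role": "user", "content": "Please introduce Harry Potter."}],
    temperature=0.3,
    top_p=1.0,
    max_completion_tokens=200,  # capped server-side by the configured max_new_tokens
    # max_tokens=200,           # if set, takes priority over max_completion_tokens
)
print(response.choices[0].message.content)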

View file

@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tools▁end>":"<tool▁calls▁end>"
         }
         # Use check_client_connected for early stopping
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 # Final return on utilization
                 raw_usage = res
@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tool▁end>":"<tool▁call▁end>",
             "<tools▁end>":"<tool▁calls▁end>"
         }
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
                 usage = CompletionUsage(

View file

@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()

 @router.post("/completions",tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request:Request, create:CompletionCreate):
     id = str(uuid4())
     interface = get_interface()
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     raw_usage = res
                 else:
@@ -32,7 +32,7 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):
            if isinstance(res, RawUsage):
                raw_usage = res
            else:
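The plain completions endpoint above accepts the same limits. A minimal sketch of a non-streaming request against it, assuming the same /v1 prefix and port as the chat endpoint used in the test scripts; all values are illustrative.

import requests

resp = requests.post(
    "http://localhost:10002/v1/completions",
    json={
        "model": "DeepSeek-V3",
        "prompt": "I want to learn Python. Please give me some advice.",
        "stream": False,
        "temperature": 0.6,
        "max_completion_tokens": 128,  # or "max_tokens", which takes priority if both are set
    },
)
print(resp.json())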

View file

@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
     try:
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
             start_event.wait()

-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens

     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
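Restated outside the diff: the new get_params lets an explicit max_tokens override max_completion_tokens, then clamps the result to the server-side max_new_tokens. A standalone sketch of that precedence (the helper name and the numbers are illustrative):

def resolve_max_new_tokens(server_cap, max_tokens=None, max_completion_tokens=None):
    # max_tokens, when provided, takes priority over max_completion_tokens
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    # fall back to the server-side default, otherwise clamp to it
    if max_completion_tokens is None:
        return server_cap
    return min(server_cap, max_completion_tokens)

assert resolve_max_new_tokens(1000) == 1000                           # neither parameter sent
assert resolve_max_new_tokens(1000, max_tokens=200) == 200            # max_tokens wins
assert resolve_max_new_tokens(1000, max_completion_tokens=300) == 300
assert resolve_max_new_tokens(100, max_tokens=500) == 100             # clamped to the server cap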
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                     profiler.pause_timer("decode")
                     report_last_time_performance(profiler)
                     yield self.streamer.end(), None
-                    if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                    if profiler.get_counter('decode') >= max_new_tokens - 1:
                         yield "", "length"
                     else:
                         yield "", "stop"
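A worked example of the estimated_length guard in the hunk above, with illustrative numbers (not taken from any real configuration):

cache_lens, query_length, max_new_tokens = 4096, 4000, 200
estimated_length = min(cache_lens, query_length + max_new_tokens)  # 4096: capped by the KV cache
assert estimated_length >= query_length   # otherwise the request is rejected as "query too long"
room_for_generation = estimated_length - query_length
print(room_for_generation)                # 96: at most 96 new tokens fit for this request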

View file

@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
@@ -222,6 +228,7 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @property
@@ -229,9 +236,9 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
         # return this inference raw usage
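The generation buffer sizing above, shown as a standalone sketch; batch size, prompt length and token budget are illustrative:

import torch

batch_size, input_len, max_new_tokens = 1, 8, 50
# prompt length + per-request decode budget + 1, mirroring the shape used in prefill above
generated_ids = torch.zeros(batch_size, input_len + max_new_tokens + 1, dtype=torch.int)
print(generated_ids.shape)  # torch.Size([1, 59])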

View file

@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if is_new:
             self.ever_generated_ids.clear()
             same_prefix = 0
@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = self.seq_length + max_new_tokens + 1
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
-        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end(), "length"
             return
-        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.max_new_tokens):
@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
-        # Check if tools are present
-        has_tools = tools is not None and len(tools) > 0
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
             )
         self.profiler.pause_timer("tokenize")

         self.profiler.create_and_start_timer("prefill")
         if Config().user_force_think:
@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think, None
-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
-            # output think token after prefill done
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        # Handle tool calling
-        if has_tools:
-            # Start collecting tokens until we detect a tool call
-            collected_tokens = ""
-            is_collecting_tool_call = False
-            is_function_name_collected = False
-            function_name = ""
-            collected_arguments = ""
-            brackets_count = 0
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    collected_tokens += t
-                    # Check if we're starting a tool call
-                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-                        is_collecting_tool_call = True
-                        # Generate a unique tool call ID
-                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-                        # Send first tool call info
-                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-                            # If tools are provided, use the first one's name
-                            recommended_function = tools[0].function.name
-                        else:
-                            # Otherwise try to extract from context
-                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                            recommended_function = function_match.group(1) if function_match else ""
-                        yield {
-                            'tool_call': {
-                                'id': tool_call_id,
-                                'type': 'function',
-                                'index': 0,
-                                'function': {
-                                    'name': recommended_function,
-                                    'arguments': ""
-                                }
-                            },
-                            'first_chunk': True
-                        }
-                    # Extract function name if we're collecting tool call
-                    if is_collecting_tool_call and not is_function_name_collected:
-                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                        if name_match:
-                            function_name = name_match.group(1)
-                            is_function_name_collected = True
-                    # Track argument collection
-                    if is_collecting_tool_call and is_function_name_collected:
-                        args_position = collected_tokens.find('"arguments"')
-                        if args_position > -1:
-                            # Find the start of the JSON object after "arguments":
-                            json_start = collected_tokens.find('{', args_position)
-                            if json_start > -1:
-                                for i in range(json_start, len(collected_tokens)):
-                                    char = collected_tokens[i]
-                                    collected_arguments += char
-                                    if char == '{':
-                                        brackets_count += 1
-                                    elif char == '}':
-                                        brackets_count -= 1
-                                    # Check if we've completed the arguments JSON
-                                    if brackets_count == 0:
-                                        # Send argument chunk
-                                        yield {
-                                            'tool_call': {
-                                                'id': tool_call_id,
-                                                'type': 'function',
-                                                'function': {
-                                                    'name': function_name,
-                                                    'arguments': collected_arguments
-                                                }
-                                            },
-                                            'argument_chunk': collected_arguments,
-                                            'last_chunk': True,
-                                            'prompt_tokens': 176,
-                                            'completion_tokens': 20
-                                        }
-                                        # Reset for next potential tool call
-                                        collected_tokens = ""
-                                        is_collecting_tool_call = False
-                                        is_function_name_collected = False
-                                        function_name = ""
-                                        collected_arguments = ""
-                                        brackets_count = 0
-                                        break
-                # Handle finish reason
-                if finish_reason is not None:
-                    yield "", finish_reason
-            print("")
-        else:
-            # Regular text generation (no tools)
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="",flush=True)
-                    yield t, finish_reason
-            print("")
+        for t, finish_reason in self.generate():
+            if t is not None:
+                print(t, end="",flush=True)
+                yield t, finish_reason
+        print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
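A worked example of the per-request decode budget that prefill now sets (all numbers illustrative):

max_new_tokens, cache_lens, seq_length = 200, 4096, 3990
decode_budget = min(max_new_tokens, cache_lens - seq_length) - 1  # min(200, 106) - 1 = 105
assert decode_budget == 105
# if the budget is <= 0, generate() logs a warning and immediately finishes with reason "length"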

View file

@@ -1,7 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object
@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice
 from uuid import uuid4
-from pydantic import BaseModel, Field

 class Role(Enum):
     system = 'system'
@@ -67,6 +65,8 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
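Both new fields default to 50 in the schema above, so a request that sets neither is effectively capped at 50 completion tokens (further clamped by the server's max_new_tokens). A sketch of a request body that overrides them; the values are illustrative:

payload = {
    "model": "DeepSeek-V3",
    "messages": [{"role": "user", "content": "Please tell me a joke"}],
    "stream": True,
    "max_tokens": 256,             # takes priority when both fields are present
    "max_completion_tokens": 512,  # used only when max_tokens is absent; both are clamped server-side
}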

View file

@@ -1,7 +1,6 @@
 from typing import List, Optional
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from ..base import Object
@@ -9,8 +8,10 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1)
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):

View file

@@ -0,0 +1,45 @@
+from openai import OpenAI
+
+def send_messages(messages):
+    response = client.chat.completions.create(
+        model="deepseek-chat",
+        messages=messages,
+        tools=tools
+    )
+    return response.choices[0].message
+
+client = OpenAI(
+    api_key="placeholder",
+    base_url="http://0.0.0.0:10002/v1",
+)
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather of an location, the user shoud supply a location first",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    }
+                },
+                "required": ["location"]
+            },
+        }
+    },
+]
+
+messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
+message = send_messages(messages)
+print(f"User>\t {messages[0]['content']}")
+print(message)
+
+tool = message.tool_calls[0]
+messages.append(message)
+messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
+message = send_messages(messages)
+print(f"Model>\t {message.content}")

View file

@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']

-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
     try:
-        payload = {
-            "messages": [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": prompt_list[request_id]}
-            ],
-            "model": "DeepSeek-V3",
-            "temperature": 0.3,
-            "top_p": 1.0,
-            "stream": True  # enable streaming output
-        }

         headers = {
             'accept': 'application/json',
@@ -103,7 +94,35 @@ async def fetch_event_stream(session, request_id):
 async def main(prompt_id):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, prompt_id)]
+        payload = {
+            "messages": [
+                {"role": "system", "content": ""},
+                {"role": "user", "content": prompt_list[prompt_id]}
+            ],
+            "model": "DeepSeek-V3",
+            "stream": True,
+            "max_completion_tokens": 2,
+            # "temperature": 0.3,
+            # "top_p": 1.0,
+            # "max_tokens" : 20,
+        }
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["temperature"] = 0.3
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["top_p"] = 1
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["max_tokens"] = 200
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["stream"] = False
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)

 if __name__ == "__main__":

View file

@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         default:
             {
-                printf("case:%d",typeA);
+                // printf("case:%d",typeA);
                 return false;
             }