mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00

commit 3a044e6b14 ("change test"), parent b17ab8653c
4 changed files with 134 additions and 115 deletions
@@ -25,19 +25,10 @@ class DataEvaluator:
         """
         # Read the Parquet file
         # dataset = load_dataset('parquet', data_files=file_path)
-        ds = load_dataset(file_path,"all")
-        df = pd.DataFrame(ds['test'])
-        # print(ds)
-        # # ds_1 = ds['train']
-        # ds_2 = ds['validation']
-        # ds_3 = ds['test']
-        # # Convert the dataset to a Pandas DataFrame
-        # df_test = pd.DataFrame(ds['test'])
-        # df_val = pd.DataFrame(ds['validation'])
-
-        # for _, row in df.iterrows():
-        #     self.data.append(row.to_dict())
-        # df = pd.read_parquet(file_path)
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])

         for _, row in df.iterrows():
             self.data.append(row.to_dict())
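Note on the hunk above: instead of going through `datasets.load_dataset`, the test now reads the MMLU test split directly from the Hugging Face Hub over the `hf://` filesystem. A minimal standalone sketch of that access pattern, assuming `pandas` and `huggingface_hub` are installed (fsspec resolves the `hf://` scheme through `huggingface_hub`):

    import pandas as pd

    # Read one split of cais/mmlu straight from the Hub; no explicit download
    # step and no `datasets` dependency. Path mirrors the layout in the diff.
    df = pd.read_parquet("hf://datasets/cais/mmlu/all/test-00000-of-00001.parquet")
    print(len(df), df.columns.tolist())  # row count plus columns such as question/choices/answer

One practical consequence: the `file_path` argument of `load_data` no longer selects the dataset; the MMLU test split is hard-coded.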
@@ -8,12 +8,57 @@ from datasets import load_dataset
 import os
 import concurrent.futures
 import threading
+import re

 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
 os.environ['https_proxy'] = ''
 os.environ['http_proxy'] = ''
 hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'


+def extract_final_answer(text):
+    """
+    Extract the model's final predicted choice (e.g. A/B/C/D).
+    Supports natural-language, multi-line, markdown, highlighted, and
+    non-final-line answer formats.
+    """
+    text = text.strip()
+
+    # 1. Explicit statement match (takes priority)
+    explicit_patterns = [
+        r'Answer:\s*([A-D])\b',
+        r'Correct answer:\s*([A-D])\b',
+        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
+        r'Answer is\s*([A-D])\b',
+        r'Therefore,\s*answer is\s*([A-D])\b',
+        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'Option\s+([A-D])\s+is correct',
+    ]
+    for pat in explicit_patterns:
+        match = re.search(pat, text, re.IGNORECASE)
+        if match:
+            return match.group(1).upper()
+
+    # 2. Markdown emphasis: **C**, **C. something**
+    markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
+    if markdown_match:
+        return markdown_match[-1].upper()
+
+    # 3. Look for 'C' or "C" inside quotes
+    quote_match = re.findall(r"['\"]([A-D])['\"]", text)
+    if quote_match:
+        return quote_match[-1].upper()
+
+    # 4. Check whether one of the last few lines starts with "C." or "C"
+    lines = text.splitlines()
+    for line in reversed(lines[-5:]):
+        line = line.strip()
+        match = re.match(r'^([A-D])([.\s]|$)', line)
+        if match:
+            return match.group(1).upper()
+
+    # Failing all of that, return None
+    return None
 class DataEvaluator:
     def __init__(self):
         self.data = []
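To see which fallback fires where, a quick sketch exercising the new `extract_final_answer` (the sample completions are invented for illustration):

    # Invented model outputs, one per extraction rule in the function above.
    samples = [
        "Let us reason it out. The correct answer is **B**.",  # rule 1: explicit statement
        "**C. Paris** matches all the clues.",                 # rule 2: markdown emphasis
        "That leaves only 'D'.",                               # rule 3: quoted letter
        "Step 1: ...\nStep 2: ...\nA. This follows.",          # rule 4: line-initial letter
        "I cannot tell.",                                      # no rule matches
    ]
    print([extract_final_answer(s) for s in samples])  # ['B', 'C', 'D', 'A', None]

Since every rule uppercases the captured letter, the caller can compare the result directly against `chr(data_item['answer'] + 65)`.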
@@ -22,8 +67,10 @@ class DataEvaluator:
         """
         Load data from the data file; each record corresponds to one instance.
         """
-        ds = load_dataset(file_path, "all")
-        df = pd.DataFrame(ds['test'])
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
         for _, row in df.iterrows():
             self.data.append(row.to_dict())
@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
 def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     start_total_time = time.time()
     total_score = 0
+    total_exact_score = 0
     results = []
     file_lock = threading.Lock()
@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi

     def worker(index, data_item):
         nonlocal total_score
+        nonlocal total_exact_score
         question = data_evaluator.get_prompt(data_item)
         start_time = time.time()
         try:
@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
             answer = chr(data_item['answer'] + 65)
             processed_prediction = data_evaluator.post_processing(prediction)
             score = data_evaluator.score(processed_prediction, answer)
+            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
             elapsed_time = time.time() - start_time
             result_data = {
                 "question_id": index,
                 "answer": answer,
                 "prediction": processed_prediction,
-                "real_prediction": prediction,
+                "full_prediction": prediction,
                 "score": score,
+                "exact_score": exact_score,
                 "time": elapsed_time
             }
             # Take the lock when writing results to keep file writes thread-safe
@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
         if res is not None:
             results.append(res)
             total_score += res['score']
+            total_exact_score += res['exact_score']

     total_time = time.time() - start_total_time
     throughput = len(data_subset) / total_time if total_time > 0 else 0
@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
         log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
         average_score = total_score / len(data_subset) if data_subset else 0
         log_f.write(f"Average Score: {average_score}\n")
+        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
+        log_f.write(f"Average Exact Score: {average_exact_score}\n")
         log_f.write('-' * 40 + '\n')

     print(f"Results saved to {result_file}")
@@ -152,4 +206,4 @@ if __name__ == "__main__":
     data_evaluator = DataEvaluator()
     data_evaluator.load_data(args.file)

     main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
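The new `exact_score` rides along the existing thread-pool plumbing: each worker scores one item two ways (lenient `post_processing` match versus strict `extract_final_answer` match), and both totals are summed as results come back. A stripped-down sketch of that aggregation pattern, with the executor wiring assumed (the diff only shows the totals and the worker):

    import concurrent.futures

    def evaluate_all(items, worker, max_workers=8):
        # worker(index, item) -> dict with 'score' and 'exact_score', or None on failure
        total_score = 0
        total_exact_score = 0
        results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = [pool.submit(worker, i, item) for i, item in enumerate(items)]
            for fut in concurrent.futures.as_completed(futures):
                res = fut.result()
                if res is not None:
                    results.append(res)
                    total_score += res['score']              # lenient, post-processed match
                    total_exact_score += res['exact_score']  # strict extracted-letter match
        n = len(items)
        return (total_score / n if n else 0, total_exact_score / n if n else 0)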
@@ -2,23 +2,18 @@ import asyncio
 import json
 import sys
 import aiohttp
-import random
 import argparse
-import yaml
-import os
-import time
-from time import sleep

-decodesz = 128
-# Server URL (replace with your server URL)
-SERVER_URL = "http://localhost:10002/v1/chat/completions"
-bf_list = [1]
-decodesz_list = [128]
-prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-async def fetch_event_stream(session, payload, request_id):
+prompt_list = [
+    'Please elaborate on modern world history.',
+    'Please introduce Harry Potter.',
+    'I want to learn Python. Please give me some advice.',
+    'Please tell me a joke '
+]
+
+async def fetch_event_stream(session, payload, request_id, stream):
     try:


         headers = {
             'accept': 'application/json',
             'Content-Type': 'application/json'
@@ -31,104 +26,80 @@ async def fetch_event_stream(session, payload, request_id):
                 print(f"Request {request_id}: Error, status {response.status}")
                 return

-            output_text = ""  # all tokens of the current response
-            total_tokens = 0  # total token count
-            decode_start_time = None  # start time of the decode phase
-            decode_end_time = None  # time decode finished
-
-            async for line in response.content:
-                try:
-                    decoded_line = line.decode("utf-8").strip()
-
-                    # filter out empty lines
-                    if not decoded_line or not decoded_line.startswith("data: "):
-                        continue
-
-                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
-
-                    # make sure the JSON payload is valid
-                    if not decoded_line:
-                        continue
-
-                    response_data = json.loads(decoded_line)  # parse the JSON
-
-                    # make sure choices exists
-                    choices = response_data.get("choices", [])
-                    if not choices:
-                        continue
-
-                    delta = choices[0].get("delta", {})
-                    token = delta.get("content", "")
-
-                    if token:
-                        if decode_start_time is None:
-                            decode_start_time = time.time()  # record decode start time
-                        output_text += token  # append the token
-                        sys.stdout.write(token)  # emit the token directly
-                        sys.stdout.flush()  # flush immediately so the token appears at once
-                        total_tokens += 1  # increment the token count
-                        decode_end_time = time.time()  # refresh decode end time on every token
-
-                    # check whether generation finished
-                    finish_reason = choices[0].get("finish_reason", None)
-                    if finish_reason:
-                        # print(f"\nRequest {request_id}: Done")
-                        break  # stop stream processing
-
-                except json.JSONDecodeError as e:
-                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
-                except IndexError:
-                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
-                except Exception as e:
-                    print(f"\nRequest {request_id}: Error parsing stream - {e}")
-
-            # compute decode speed
-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
+            output_text = ""
+
+            if stream:
+                async for line in response.content:
+                    try:
+                        decoded_line = line.decode("utf-8").strip()
+                        if not decoded_line or not decoded_line.startswith("data: "):
+                            continue
+
+                        decoded_line = decoded_line[6:].strip()
+                        if not decoded_line:
+                            continue
+
+                        response_data = json.loads(decoded_line)
+                        choices = response_data.get("choices", [])
+                        if not choices:
+                            continue
+
+                        delta = choices[0].get("delta", {})
+                        token = delta.get("content", "")
+
+                        if token:
+                            output_text += token
+                            sys.stdout.write(token)
+                            sys.stdout.flush()
+
+                        finish_reason = choices[0].get("finish_reason", None)
+                        if finish_reason:
+                            break
+
+                    except json.JSONDecodeError as e:
+                        print(f"\nRequest {request_id}: JSON Decode Error - {e}")
+                    except IndexError:
+                        print(f"\nRequest {request_id}: List Index Error - choices is empty")
+                    except Exception as e:
+                        print(f"\nRequest {request_id}: Error parsing stream - {e}")
+            else:
+                # In non-stream mode, receive the complete JSON in one shot
+                response_data = await response.json()
+                choices = response_data.get("choices", [])
+                if choices:
+                    content = choices[0].get("message", {}).get("content", "")
+                    print(f"Request {request_id} Output:\n{content}")
+                    output_text += content

     except Exception as e:
         print(f"\nRequest {request_id}: Exception - {e}")

-async def main(prompt_id):
+async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
     async with aiohttp.ClientSession() as session:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt_list[prompt_id]}
             ],
-            "model": "DeepSeek-V3",
-            "stream": True,
-            "max_completion_tokens": 2,
-            # "temperature": 0.3,
-            # "top_p": 1.0,
-            # "max_tokens" : 20,
+            "model": model,
+            "stream": stream,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p
         }
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
-        await asyncio.gather(*tasks)
-
-        payload["temperature"] = 0.3
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["top_p"] = 1
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["max_tokens"] = 200
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["stream"] = False
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)

 if __name__ == "__main__":

     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
-    parser.add_argument("--question_id", type=int, default=0, required=False)
+    parser.add_argument("--question_id", type=int, default=0)
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--stream", type=bool, default=True)
+    parser.add_argument("--max_tokens", type=int, default=500)
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--top_p", type=float, default=1)
+    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")

     args = parser.parse_args()
-    output_file = "ktransformer_test_results.txt"
-    asyncio.run(main(args.question_id))
+    SERVER_URL = args.api_url
+    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
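Net effect of this file's changes: where the old script re-sent the same prompt five times with hard-coded sampling variants, the reworked client issues a single request whose settings all come from the CLI. A sample invocation (the script filename is a placeholder; the flags are the ones defined above):

    python chat_test.py --question_id 2 --model DeepSeek-V3 \
        --max_tokens 500 --temperature 0.8 --top_p 1 \
        --api_url http://localhost:10006/v1/chat/completions

One caveat: argparse applies `type=bool` to the raw argument string, and any non-empty string is truthy, so `--stream False` still enables streaming; only an empty string yields False. An `action="store_true"` flag (with a paired `--no-stream`) would behave as users expect.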
@@ -45,14 +45,14 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
-async def fetch_event_stream(session, request_id, prompt, max_tokens):
+async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
     try:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt}
             ],
-            "model": "DeepSeek-V3",
+            "model": model,
             "temperature": 0.3,
             "top_p": 1.0,
             "stream": True,
@@ -134,17 +134,19 @@ async def fetch_event_stream(session, request_id, prompt, max_tokens):
     except Exception as e:
         print(f"[Request {request_id}] Exception: {e}")

-async def main(concurrent_requests , prompt, max_tokens):
+async def main(concurrent_requests , prompt, max_tokens, model):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
         if len(prefill_speeds) != 0:
             import numpy as np
-            print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}")
+            print(f"concurrency: {len(prefill_speeds)}")
+            print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
@@ -152,9 +154,10 @@ if __name__ == "__main__":
     args = parser.parse_args()
     SERVER_URL = args.api_url
     max_tokens = args.max_tokens
+    model = args.model
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
-    asyncio.run(main(args.concurrent, prompt, max_tokens))
+    asyncio.run(main(args.concurrent, prompt, max_tokens, model))
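And a sample run of the updated concurrency benchmark (placeholder script name; the flags are the ones defined above):

    python speed_test.py --concurrent 4 --model DeepSeek-V3 \
        --prompt_lens 1024 --max_tokens 50 \
        --api_url http://localhost:10002/v1/chat/completions

Worth noting: with `required=True` set on `--model`, argparse errors out whenever the flag is omitted, so its `default="DeepSeek-V3"` is never used; dropping either the default or `required=True` would make the intent unambiguous.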