Merge branch 'kvcache-ai:main' into main

Yuhao Tsui 2025-03-10 09:10:28 +08:00 committed by GitHub
commit e5694f91c0
17 changed files with 356 additions and 163 deletions

View file

@@ -8,4 +8,4 @@ Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2025-02-15 03:53:02
'''
__version__ = "0.2.3"
__version__ = "0.2.3.post1"

View file

@@ -175,6 +175,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
list(APPEND ARCH_FLAGS -mavx512bw)
list(APPEND ARCH_FLAGS -mavx512dq)
list(APPEND ARCH_FLAGS -mavx512vnni)
list(APPEND ARCH_FLAGS -mavx512vpopcntdq)
endif()
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
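The hunk above appends -mavx512vpopcntdq to the AVX-512 flag list, so these build options now assume VPOPCNTDQ support on the target CPU. A small, Linux-only helper (an editor's sketch, not part of the commit) for checking whether the host CPU reports the corresponding flags before enabling the options; flag names follow /proc/cpuinfo spelling:

def has_cpu_flags(*flags: str) -> dict:
    # Parse the first "flags" line of /proc/cpuinfo (Linux only).
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                present = set(line.split(":", 1)[1].split())
                return {flag: flag in present for flag in flags}
    return {flag: False for flag in flags}

if __name__ == "__main__":
    # avx512_vpopcntdq corresponds to the newly required -mavx512vpopcntdq.
    print(has_cpu_flags("avx512bw", "avx512dq", "avx512_vnni", "avx512_vpopcntdq", "avx512_bf16"))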

View file

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_groupe
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
logger = logging.getLogger("attention")

View file

@@ -1,9 +1,11 @@
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
Version : 0.2.2
Version : 0.2.3
'''
import torch
import os
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:
import math
def attention_ref(
def attention_ref_torch(
batch_size,
q: torch.Tensor,
k: torch.Tensor,
@@ -122,7 +124,7 @@ class MLAWrapper():
if kv_indices is None:
assert self.max_batch_size == 1
kv_indices = self.kv_indices_buf
self.wrapper.plan(
qo_indptr,
kv_indptr,
@@ -139,11 +141,6 @@ class MLAWrapper():
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
#print("run")
#print(self.wrapper._qo_indptr_buf)
#print(self.wrapper._kv_indptr_buf)
#print(self.wrapper._kv_indices_buf)
#print(self.wrapper._kv_len_arr_buf)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
def checksame():
flashinfer_folder = "./flashinfer_output"
flashinfer_folder = "./kv_cache_flashinfer"
triton_folder = "./triton_output"
triton_folder = "./kv_cache_triton"
max_layer_id = 1
max_forward_id = 2
for forward_id in range(0, 19):
print("forward_id", forward_id)
for layer_id in range(max_layer_id):
print(layer_id)
#file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
file_name = f"layer_{layer_id}.pt"
flashinfer_path = os.path.join(flashinfer_folder, file_name)
triton_path = os.path.join(triton_folder, file_name)
if not os.path.exists(triton_path):
print(f"{file_name} not exist in {triton_folder}")
continue
if not os.path.exists(flashinfer_path):
print(f"{file_name} not exist in {flashinfer_folder}")
continue
flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
try:
torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
except AssertionError as e:
print(e)
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
#checksame()
#exit(0)
max_batch_size = 1
max_pages = 128
max_pages = 64
page_size = 64
num_heads = 128
# warm-up
kv_len = 4023
q_len = 1
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
wrapper = MLAWrapperSingleton.get_instance(
@@ -241,51 +276,105 @@ if __name__ == "__main__":
torch.bfloat16,
)
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
print(attn_output.shape)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
kv_len = 6789
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
# warm-up finished
for forward_id in range(0, 1):
print("forward_id", forward_id)
for layer_id in range(1):
print(layer_id)
flashinfer_folder = "./kv_cache_flashinfer"
forward_id = 17
layer_id = 0
file_name = f"layer_{layer_id}.pt"
kv_cache_path = os.path.join(flashinfer_folder, file_name)
flashinfer_folder = "./flashinfer_output"
q_len = 1
kv_len = 126
file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
q = torch.cat([q_nope, q_pe], dim=-1)
kv_cache = torch.load(kv_cache_path).to(device="cuda")
pages, page_size, _, head_dim = kv_cache.shape
kv_cache = kv_cache.view(pages, page_size, head_dim)
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
graph.replay()
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
print(k[:kv_len].shape)
print(v[:kv_len].shape)
attn_ref, lse_ref = attention_ref(
max_batch_size,
torch.cat([q_nope, q_pe], dim=-1),
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
print(attn_ref.shape)
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
None,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
q_nope_buf.copy_(q_nope)
q_pe_buf.copy_(q_pe)
kv_buf[:pages].copy_(kv_cache)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
# ref_torch
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
max_batch_size,
q,
k[:kv_len],
v[:kv_len],
False,
192 ** (-0.5)
)
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
# ref_triton
attn_logits = torch.empty(
(
max_batch_size,
num_heads,
4, #num_kv_splits # follow vLLM, fix it TODO
512 + 1,
),
dtype=torch.float32,
device = "cuda"
)
triton_ref = torch.zeros_like(q_nope)
page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
ckv = ckv.view(pages, page_size, 1, 512)
decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
page_table,
kv_len_arr, attn_logits,
4, #num_kv_splits # follow vLLM, fix it TODO
192 ** (-0.5),
page_size)
torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
#file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#ktrans_output = torch.load(file_name)
#torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
print("test past")
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
print("test past")

View file

@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.schemas.endpoints.chat import RawUsage
router = APIRouter(prefix='/api')
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -61,14 +63,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
if input.stream:
async def inner():
async for token in interface.inference(input.prompt, id):
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
async for res in interface.inference(input.prompt, id):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
@@ -142,14 +148,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
eval_count = 0 # count the number of generated tokens
tokens = []
async for token in interface.inference(prompt, id):
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={"role": "assistant", "content": token},
done=False
)
yield d.model_dump_json() + '\n'
async for res in interface.inference(prompt, id):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={"role": "assistant", "content": token},
done=False
)
yield d.model_dump_json() + '\n'
# compute performance statistics
end_time = time()
total_duration = int((end_time - start_time) * 1_000_000_000) # convert to nanoseconds
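Both Ollama endpoints above switch from iterating plain tokens to the new inference protocol: interface.inference() now yields either a RawUsage object (once, at the end) or (token, finish_reason) tuples. A minimal consumer sketch of that contract (the helper name and flow are illustrative, not part of the commit):

from ktransformers.server.schemas.endpoints.chat import RawUsage

async def collect_stream(interface, prompt: str, request_id: str):
    # Accumulate streamed tokens and pick up the trailing usage record.
    text, usage, finish_reason = "", None, None
    async for res in interface.inference(prompt, request_id):
        if isinstance(res, RawUsage):
            usage = res                      # yielded once after generation ends
        else:
            token, finish_reason = res       # finish_reason stays None until the stream ends
            text += token
    return text, finish_reason, usage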

View file

@@ -5,10 +5,16 @@ from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from openai.types.chat import ChatCompletion
from openai.types.completion_usage import CompletionUsage
router = APIRouter()
@router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
if create.stream:
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
async def inner():
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
chunk.set_token(token)
yield chunk
return chat_stream_response(request,inner())
chunk = ChatCompletionChunk(
id = id,
choices = [],
object = 'chat.completion.chunk',
created = int(time()),
model = Config().model_name,
)
async for res in interface.inference(input_message,id, create.temperature, create.top_p):
if isinstance(res, RawUsage):
# at the end of inference, interface.inference() yields the usage stats for this request
raw_usage = res
chunk.choices = []
chunk.usage = CompletionUsage(
prompt_tokens = raw_usage.prefill_count,
completion_tokens = raw_usage.decode_count,
total_tokens = raw_usage.prefill_count + raw_usage.decode_count
)
yield chunk
else:
token, finish_reason = res
choice = Choice(
index = 0,
delta = ChoiceDelta(content=token, role=None, tool_calls=None),
finish_reason = finish_reason,
logprobs = None,
)
chunk.choices = [choice]
yield chunk
return chat_stream_response(request, inner())
else:
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
comp.append_token(token)
return comp
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
content = ""
finish_reason = None
async for res in interface.inference(input_message,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
usage = CompletionUsage(
prompt_tokens = raw_usage.prefill_count,
completion_tokens = raw_usage.decode_count,
total_tokens = raw_usage.prefill_count + raw_usage.decode_count
)
else:
token, finish_reason = res
content = content + token
finish_reason = finish_reason
choice = Choice(
index = 0,
finish_reason = finish_reason,
message = ChatCompletionMessage(
content=content,
role="assistant"
))
chat_completion = ChatCompletion(
id = id,
choices = [choice],
created = int(time()),
model = Config().model_name,
object = 'chat.completion',
usage = usage
)
return chat_completion
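Since the streaming branch above now attaches a CompletionUsage to the final chunk, an OpenAI-compatible client can read token counts once the stream finishes. A hedged client-side sketch (base URL, API key, and model name are placeholders; the local port mirrors the default used elsewhere in this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:10002/v1", api_key="not-used")
stream = client.chat.completions.create(
    model="DeepSeek-V3",                      # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:                         # ordinary token deltas
        print(chunk.choices[0].delta.content or "", end="")
    if chunk.usage is not None:               # final chunk carries the usage here
        print("\nprompt:", chunk.usage.prompt_tokens,
              "completion:", chunk.usage.completion_tokens)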

View file

@@ -6,6 +6,7 @@ from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
from ktransformers.server.schemas.endpoints.chat import RawUsage
router = APIRouter()
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
if create.stream:
async def inner():
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
yield f"data:{json.dumps(d)}\n\n"
return stream_response(request,inner())
else:
comp = CompletionObject(id=id,object='text_completion',created=int(time()))
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
comp.append_token(token)
async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
comp.append_token(token)
return comp

View file

@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler
@@ -142,12 +143,16 @@ class ThreadContext:
yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
yield self.run.stream_response_with_event(RunObject.Status.in_progress)
async for token in self.interface.inference(local_messages,self.thread.id):
if self.run.status == RunObject.Status.cancelling:
logger.warn(f'Run {self.run.id} cancelling')
break
yield reply_message.append_message_delta(token)
response_str_count+=1
async for res in self.interface.inference(local_messages,self.thread.id):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
if self.run.status == RunObject.Status.cancelling:
logger.warn(f'Run {self.run.id} cancelling')
break
yield reply_message.append_message_delta(token)
response_str_count+=1
if self.run.status == RunObject.Status.cancelling:
yield self.run.stream_response_with_event(RunObject.Status.cancelled)

View file

@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage
warm_uped = False
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
async with self._infer_lock:
async for v in super().inference(local_messages, thread_id, temperature, top_p):
yield v
# yield the raw usage stats for this inference
yield RawUsage(
tokenize_time = self.profiler.get_timer_sec('tokenize'),
prefill_time = self.profiler.get_timer_sec('prefill'),
decode_time = self.profiler.get_timer_sec('decode'),
prefill_count = self.profiler.get_counter('prefill'),
decode_count = self.profiler.get_counter('decode'),
)
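The RawUsage record yielded above exposes the profiler's timers and counters directly, so downstream code can derive throughput from it. An illustrative helper (an editor's sketch, not part of the commit):

from ktransformers.server.schemas.endpoints.chat import RawUsage

def summarize_usage(u: RawUsage) -> str:
    # Guard against zero timers to avoid division errors on empty runs.
    prefill_tps = u.prefill_count / u.prefill_time if u.prefill_time > 0 else 0.0
    decode_tps = u.decode_count / u.decode_time if u.decode_time > 0 else 0.0
    return (f"prefill {u.prefill_count} tok @ {prefill_tps:.1f} tok/s, "
            f"decode {u.decode_count} tok @ {decode_tps:.1f} tok/s")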

View file

@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
if(self.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end()
yield self.streamer.end(), "length"
return
logger.info(f"max_new_tokens: {self.max_new_tokens}")
self.profiler.set_counter("decode", 0)
@@ -344,14 +344,21 @@
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
yield self.streamer.end(), None
yield "", "stop"
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
yield self.streamer.end()
yield self.append_new_tokens(next_token), None
else: # for's else, if output get max new tokens
yield self.streamer.end(), None
yield "", "length"
def check_is_new(self, thread_id: str):
if not self.use_static_cache:
@@ -391,20 +398,20 @@
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t in self.generate():
for t, finish_reason in self.generate():
if t is not None:
print(t, end="",flush=True)
yield t
yield t, finish_reason
print("")
self.profiler.pause_timer("decode")
self.report_last_time_performance()

View file

@@ -5,6 +5,7 @@ langchain >= 0.2.0
blessed >= 1.20.0
accelerate >= 0.31.0
sentencepiece >= 0.1.97
openai
setuptools
build
ninja

View file

@@ -1,10 +1,15 @@
from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel
from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
class Role(Enum):
system = 'system'
user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
class FinishReason(Enum):
stop = 'stop'
length = 'length'
class Choice(BaseModel):
index: int
message: Message
logprobs: Optional[str] = None
finish_reason: FinishReason = None
class ChatCompletionChunk(BaseModel):
id: str
choices: List[Choice]
created: int
model: str
object: Literal["chat.completion.chunk"]
service_tier: Optional[Literal["scale", "default"]] = None
system_fingerprint: Optional[str] = None
usage: Optional[CompletionUsage] = None
class DeltaChoice(BaseModel):
index: int
delta: Message
logprobs: Optional[str] = None
finish_reason: FinishReason = None
class Usage(BaseModel):
completion_tokens:int
prompt_tokens:int
total_tokens:int
class ChatCompletionBase(Object):
created:int
model:str = 'not implmented'
system_fingerprint:str = 'not implmented'
usage: Optional[Usage] = None
class ChatCompletionObject(ChatCompletionBase):
choices:List[Choice] = []
def append_token(self,token:str):
if len(self.choices) == 0:
self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
self.choices[0].message.content += token
class ChatCompletionChunk(ChatCompletionBase):
choices:List[DeltaChoice] = []
def set_token(self,token:str):
self.choices = [
DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
]
def to_stream_reply(self):
return f"data: {self.model_dump_json()}\n\n"
class RawUsage(BaseModel):
tokenize_time: float
prefill_time: float
decode_time: float
prefill_count: int
decode_count: int

View file

@@ -78,13 +78,15 @@ def run_eval_api(
format_tabs: bool = False,
auth_token: str = None,
problem_file: str = None,
append: bool = False
append: bool = False,
skip: int = 0
):
data = load_data(problem_file)
pbar = tqdm.tqdm(total=len(data) * 1)
pbar.update(skip)
for i in range(len(data)):
i = i+skip
data_item = data[i]
question = data_item['Problem']
# Start the timer for this evaluation
@@ -97,6 +99,7 @@
score = get_score(completion, answer)
elapsed_time = time.time() - start_time
result = {
"index": i,
"question_id": data_item["ID"],
"answer": answer,
"prediction": completion,
@@ -114,9 +117,9 @@ pbar.update(1)
pbar.update(1)
def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
if __name__ == "__main__":
@@ -128,6 +131,7 @@ if __name__ == "__main__":
parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
parser.add_argument("--no_append", action="store_false", help="Append to existing file")
parser.add_argument("--skip", type=int, default=0, help="Skip some tasks")
args = parser.parse_args()
# api_url = "https://api.siliconflow.cn/v1/chat/completions"
main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)

View file

@@ -39,7 +39,8 @@ def run_eval_api(
format_tabs: bool = False,
auth_token: str = None,
problem_file: str = None,
append: bool = False
append: bool = False,
skip: int = 0
):
if(problem_file is None):
problems = read_problems()
@@ -47,8 +48,14 @@ problems = read_problems(problem_file)
problems = read_problems(problem_file)
samples = []
pbar = tqdm.tqdm(total=len(problems) * 1)
pbar.update(skip)
try:
for task_id in problems:
# skip some tasks
if skip > 0:
skip -= 1
continue
if format_tabs:
prompt = problems[task_id]["prompt"].replace(" ", "\t")
else:
@@ -67,23 +74,26 @@ if not append:
if not append:
write_jsonl(out_path, samples,append=append)
except Exception as e:
write_jsonl(out_path, samples,append=append)
if not append:
write_jsonl(out_path, samples,append=append)
print(f"Error: {e}")
def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
#parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
parser.add_argument("--out_path", type=str, default="results/api/eval.jsonl", help="Output Path")
parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
parser.add_argument("--problem_file", type=str, default=None, help="Evalset File")
parser.add_argument("--no_append", action="store_false", help="Append to existing file")
parser.add_argument("--skip", type=int, default=0, help="Skip first n problems")
args = parser.parse_args()
# api_url = "https://api.siliconflow.cn/v1/chat/completions"
main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip)

View file

@@ -8,7 +8,7 @@ def filter_code(completion: str) -> str:
completion = completion.split('if __name__ == "__main__":')[0]
if "# Example usage" in completion:
completion = completion.split("# Example usage")[0]
return completion.split("\n\n")[0]
return completion
def fix_indents(text: str) -> str:
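The change above stops truncating completions at the first blank line. A tiny illustration (editor's note, not part of the commit) of why the old split("\n\n") behavior was lossy for multi-function completions:

completion = (
    "def helper(x):\n"
    "    return x + 1\n"
    "\n"
    "def solution(xs):\n"
    "    return [helper(x) for x in xs]\n"
)
print(completion.split("\n\n")[0])   # old behavior: everything after `helper` is dropped
print(completion)                    # new behavior: the full completion is kept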

View file

@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
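Here, as in the attention-backend change earlier in this commit, the hard-coded q_head_dim ** (-0.5) is replaced by the module's own softmax_scale. The two are not always equal: with YaRN rope scaling, DeepSeek-style attention multiplies the base scale by an extra mscale**2 correction. A paraphrased sketch of that computation (following the upstream DeepSeek modeling code; the config field names here are assumptions):

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

def mla_softmax_scale(q_head_dim: int, rope_scaling: dict | None) -> float:
    softmax_scale = q_head_dim ** (-0.5)          # the old hard-coded value
    if rope_scaling and rope_scaling.get("mscale_all_dim", 0):
        m = yarn_get_mscale(rope_scaling.get("factor", 1.0),
                            rope_scaling["mscale_all_dim"])
        softmax_scale *= m * m                    # YaRN attention-scale correction
    return softmax_scale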

View file

@@ -2388,7 +2388,8 @@ struct SimpleBits {
struct EvenSignHelper {
#ifdef HAVE_FANCY_SIMD
#if defined HAVE_FANCY_SIMD
// #pragma message("Using AVX512VPOPCNTDQ in even sign helper")
union sbits_t {
__m128i vec;
__mmask32 mask[4];