Merge branch 'kvcache-ai:main' into main

Yuhao Tsui 2025-03-10 09:10:28 +08:00 committed by GitHub
commit e5694f91c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 356 additions and 163 deletions

View file

@@ -8,4 +8,4 @@ Version : 1.0.0
 LastEditors : chenxl
 LastEditTime : 2025-02-15 03:53:02
 '''
-__version__ = "0.2.3"
+__version__ = "0.2.3.post1"

View file

@@ -175,6 +175,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         list(APPEND ARCH_FLAGS -mavx512bw)
         list(APPEND ARCH_FLAGS -mavx512dq)
         list(APPEND ARCH_FLAGS -mavx512vnni)
+        list(APPEND ARCH_FLAGS -mavx512vpopcntdq)
     endif()
     if (LLAMA_AVX512_BF16)
         list(APPEND ARCH_FLAGS -mavx512bf16)

View file

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
-    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 logger = logging.getLogger("attention")

View file

@@ -1,9 +1,11 @@
 '''
 Description : flashinfer MLA wrapper
 Author : Boxin Zhang
-Version : 0.2.2
+Version : 0.2.3
 '''
 import torch
+import os
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 
 flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:
 import math
 
-def attention_ref(
+def attention_ref_torch(
     batch_size,
     q: torch.Tensor,
     k: torch.Tensor,
@@ -139,11 +141,6 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -203,20 +200,58 @@ class MLAWrapperSingleton():
         wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
         wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 
+def checksame():
+    flashinfer_folder = "./flashinfer_output"
+    flashinfer_folder = "./kv_cache_flashinfer"
+    triton_folder = "./triton_output"
+    triton_folder = "./kv_cache_triton"
+
+    max_layer_id = 1
+    max_forward_id = 2
+
+    for forward_id in range(0, 19):
+        print("forward_id", forward_id)
+        for layer_id in range(max_layer_id):
+            print(layer_id)
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+            file_name = f"layer_{layer_id}.pt"
+            flashinfer_path = os.path.join(flashinfer_folder, file_name)
+            triton_path = os.path.join(triton_folder, file_name)
+
+            if not os.path.exists(triton_path):
+                print(f"{file_name} not exist in {triton_folder}")
+                continue
+            if not os.path.exists(flashinfer_path):
+                print(f"{file_name} not exist in {flashinfer_folder}")
+                continue
+
+            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
+            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
+
+            try:
+                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
+            except AssertionError as e:
+                print(e)
+
 if __name__ == "__main__":
+    torch.set_default_dtype(torch.bfloat16)
+    #checksame()
+    #exit(0)
+
     max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
     page_size = 64
     num_heads = 128
 
+    # warm-up
     kv_len = 4023
     q_len = 1
-    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
-    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
-    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
 
     wrapper = MLAWrapperSingleton.get_instance(
@@ -241,51 +276,105 @@ if __name__ == "__main__":
         torch.bfloat16,
     )
 
-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
     print(attn_output.shape)
 
     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph):
-        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    # warm-up finished
 
-    kv_len = 6789
-    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
-    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
-    wrapper.plan(
-        qo_indptr,
-        None,
-        None,
-        kv_len_arr,
-        128,
-        512,
-        64,
-        page_size,
-        192 ** (-0.5),
-        torch.bfloat16,
-        torch.bfloat16,
-    )
-
-    graph.replay()
-
-    k = (
-        torch.cat([ckv, k_pe], dim=-1)
-        .view(-1, 1, 512 + 64)
-        .repeat_interleave(num_heads, dim=1)
-    )
-    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-
-    print(k[:kv_len].shape)
-    print(v[:kv_len].shape)
-
-    attn_ref, lse_ref = attention_ref(
-        max_batch_size,
-        torch.cat([q_nope, q_pe], dim=-1),
-        k[:kv_len],
-        v[:kv_len],
-        True,
-        192 ** (-0.5)
-    )
-    print(attn_ref.shape)
-
-    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
-    print("test past")
+    for forward_id in range(0, 1):
+        print("forward_id", forward_id)
+        for layer_id in range(1):
+            print(layer_id)
+            flashinfer_folder = "./kv_cache_flashinfer"
+            forward_id = 17
+            layer_id = 0
+            file_name = f"layer_{layer_id}.pt"
+            kv_cache_path = os.path.join(flashinfer_folder, file_name)
+            flashinfer_folder = "./flashinfer_output"
+
+            q_len = 1
+            kv_len = 126
+            file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
+            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
+            file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            kv_cache = torch.load(kv_cache_path).to(device="cuda")
+            pages, page_size, _, head_dim = kv_cache.shape
+            kv_cache = kv_cache.view(pages, page_size, head_dim)
+            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
+
+            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+            wrapper.plan(
+                None,
+                None,
+                None,
+                kv_len_arr,
+                128,
+                512,
+                64,
+                page_size,
+                192 ** (-0.5),
+                torch.bfloat16,
+                torch.bfloat16,
+            )
+
+            q_nope_buf.copy_(q_nope)
+            q_pe_buf.copy_(q_pe)
+            kv_buf[:pages].copy_(kv_cache)
+
+            torch.cuda.synchronize()
+            graph.replay()
+            torch.cuda.synchronize()
+
+            # ref_torch
+            k = (
+                torch.cat([ckv, k_pe], dim=-1)
+                .view(-1, 1, 512 + 64)
+                .repeat_interleave(num_heads, dim=1)
+            )
+            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+            attn_ref, lse_ref = attention_ref_torch(
+                max_batch_size,
+                q,
+                k[:kv_len],
+                v[:kv_len],
+                False,
+                192 ** (-0.5)
+            )
+            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
+
+            # ref_triton
+            attn_logits = torch.empty(
+                (
+                    max_batch_size,
+                    num_heads,
+                    4, #num_kv_splits # follow vLLM, fix it TODO
+                    512 + 1,
+                ),
+                dtype=torch.float32,
+                device = "cuda"
+            )
+            triton_ref = torch.zeros_like(q_nope)
+            page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
+            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
+            ckv = ckv.view(pages, page_size, 1, 512)
+            decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
+                            page_table,
+                            kv_len_arr, attn_logits,
+                            4, #num_kv_splits # follow vLLM, fix it TODO
+                            192 ** (-0.5),
+                            page_size)
+            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
+
+            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+            #ktrans_output = torch.load(file_name)
+            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
+            print("test past")

View file

@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
 
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -61,14 +63,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
-                d = OllamaGenerationStreamResponse(
-                    model=config.model_name,
-                    created_at=str(datetime.now()),
-                    response=token,
-                    done=False
-                )
-                yield d.model_dump_json() + '\n'
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        response=token,
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
             d = OllamaGenerationStreamResponse(
                 model=config.model_name,
                 created_at=str(datetime.now()),
@@ -142,14 +148,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
         eval_count = 0  # count of generated tokens
         tokens = []
 
-        async for token in interface.inference(prompt, id):
-            d = OllamaChatCompletionStreamResponse(
-                model=config.model_name,
-                created_at=str(datetime.now()),
-                message={"role": "assistant", "content": token},
-                done=False
-            )
-            yield d.model_dump_json() + '\n'
+        async for res in interface.inference(prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                d = OllamaChatCompletionStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    message={"role": "assistant", "content": token},
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
         # compute performance stats
         end_time = time()
         total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds

View file

@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
+
 router = APIRouter()
 
 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
         assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
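
Note on the change above: with usage now carried on the final streaming chunk, a client built on the openai package (added to requirements.txt in this commit) can read both the deltas and the token counts in one loop. A minimal sketch, assuming a ktransformers server at http://localhost:10002/v1 and the model name DeepSeek-V3 (both illustrative; only the port matches the default in the humaneval eval script further down):

# Streaming client sketch; base_url, api_key and model are assumptions, not values fixed by this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:10002/v1", api_key="dummy")

stream = client.chat.completions.create(
    model="DeepSeek-V3",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

for chunk in stream:
    if chunk.choices:                      # ordinary delta chunk
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
    if chunk.usage:                        # final chunk carries CompletionUsage
        print(f"\nprompt={chunk.usage.prompt_tokens} completion={chunk.usage.completion_tokens}")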

View file

@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data:{json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data:{json.dumps(d)}\n\n"
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp

View file

@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler
@@ -142,12 +143,16 @@ class ThreadContext:
             yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
             yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-            async for token in self.interface.inference(local_messages,self.thread.id):
-                if self.run.status == RunObject.Status.cancelling:
-                    logger.warn(f'Run {self.run.id} cancelling')
-                    break
-                yield reply_message.append_message_delta(token)
-                response_str_count+=1
+            async for res in self.interface.inference(local_messages,self.thread.id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    if self.run.status == RunObject.Status.cancelling:
+                        logger.warn(f'Run {self.run.id} cancelling')
+                        break
+                    yield reply_message.append_message_delta(token)
+                    response_str_count+=1
 
             if self.run.status == RunObject.Status.cancelling:
                 yield self.run.stream_response_with_event(RunObject.Status.cancelled)

View file

@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 warm_uped = False
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+        # return this inference raw usage
+        yield RawUsage(
+            tokenize_time = self.profiler.get_timer_sec('tokenize'),
+            prefill_time = self.profiler.get_timer_sec('prefill'),
+            decode_time = self.profiler.get_timer_sec('decode'),
+            prefill_count = self.profiler.get_counter('prefill'),
+            decode_count = self.profiler.get_counter('decode'),
+        )
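
Note on the change above: every backend stream now ends with one RawUsage record after the (token, finish_reason) tuples, so callers can separate text from telemetry in a single pass. A minimal consumer sketch, assuming an already-constructed interface object and a thread_id; the helper name collect and the throughput print are illustrative, while the RawUsage fields come from this PR:

# Sketch of consuming the new stream contract: (token, finish_reason) tuples, then one RawUsage.
from ktransformers.server.schemas.endpoints.chat import RawUsage

async def collect(interface, messages, thread_id):
    text, usage = "", None
    async for res in interface.inference(messages, thread_id):
        if isinstance(res, RawUsage):
            usage = res                      # trailing telemetry record
        else:
            token, finish_reason = res
            text += token
    if usage and usage.decode_time > 0:
        print(f"decode throughput: {usage.decode_count / usage.decode_time:.2f} tok/s")
    return text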

View file

@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -344,14 +344,21 @@ class TransformersInterface(BackendInterfaceBase):
                 MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                                  num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                  head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
-                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                                 sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()

View file

@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja

View file

@@ -1,10 +1,15 @@
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum
 
 from pydantic import BaseModel
 
 from ktransformers.server.schemas.base import Object
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
+
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
-class ChatCompletionBase(Object):
-    created:int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None
 
     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
 
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
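
Note on the change above: the schema now reuses the openai package's chunk Choice and CompletionUsage types and keeps only the SSE framing in to_stream_reply(). A small sketch of how a chunk built from these pieces serializes; the id, model, and token values are made up for illustration:

# Illustrative only: one delta chunk and one final usage chunk using the new schema.
from time import time
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk

chunk = ChatCompletionChunk(
    id="chat-123",
    choices=[Choice(index=0, delta=ChoiceDelta(content="Hello"), finish_reason=None, logprobs=None)],
    created=int(time()),
    model="DeepSeek-V3",
    object="chat.completion.chunk",
)
print(chunk.to_stream_reply())          # "data: {...}\n\n" carrying the delta

chunk.choices = []
chunk.usage = CompletionUsage(prompt_tokens=12, completion_tokens=5, total_tokens=17)
print(chunk.to_stream_reply())          # final frame carrying usage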

View file

@@ -78,13 +78,15 @@ def run_eval_api(
     format_tabs: bool = False,
     auth_token: str = None,
     problem_file: str = None,
-    append: bool = False
+    append: bool = False,
+    skip: int = 0
 ):
     data = load_data(problem_file)
 
     pbar = tqdm.tqdm(total=len(data) * 1)
+    pbar.update(skip)
     for i in range(len(data)):
+        i = i+skip
         data_item = data[i]
         question = data_item['Problem']
         # Start the timer for this evaluation
@@ -97,6 +99,7 @@ def run_eval_api(
         score = get_score(completion, answer)
         elapsed_time = time.time() - start_time
         result = {
+            "index": i,
             "question_id": data_item["ID"],
             "answer": answer,
             "prediction": completion,
@@ -114,9 +117,9 @@
         pbar.update(1)
 
-def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
 
 if __name__ == "__main__":
@@ -128,6 +131,7 @@ if __name__ == "__main__":
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip some tasks")
     args = parser.parse_args()
 
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)

View file

@@ -39,7 +39,8 @@ def run_eval_api(
     format_tabs: bool = False,
     auth_token: str = None,
     problem_file: str = None,
-    append: bool = False
+    append: bool = False,
+    skip: int = 0
 ):
     if(problem_file is None):
         problems = read_problems()
@@ -47,8 +48,14 @@ def run_eval_api(
         problems = read_problems(problem_file)
     samples = []
     pbar = tqdm.tqdm(total=len(problems) * 1)
+    pbar.update(skip)
     try:
         for task_id in problems:
+            # skip some tasks
+            if skip > 0:
+                skip -= 1
+                continue
             if format_tabs:
                 prompt = problems[task_id]["prompt"].replace(" ", "\t")
             else:
@@ -67,23 +74,26 @@ def run_eval_api(
         if not append:
             write_jsonl(out_path, samples,append=append)
     except Exception as e:
-        write_jsonl(out_path, samples,append=append)
+        if not append:
+            write_jsonl(out_path, samples,append=append)
         print(f"Error: {e}")
 
-def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
-    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
-    parser.add_argument("--out_path", type=str, default="results/api/eval.jsonl", help="Output Path")
+    parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
     parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default=None, help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip first n problems")
     args = parser.parse_args()
 
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip)

View file

@@ -8,7 +8,7 @@ def filter_code(completion: str) -> str:
     completion = completion.split('if __name__ == "__main__":')[0]
     if "# Example usage" in completion:
         completion = completion.split("# Example usage")[0]
-    return completion.split("\n\n")[0]
+    return completion
 
 def fix_indents(text: str) -> str:

View file

@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         if use_flashinfer_mla:
             MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                            q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
+                            model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
             warm_uped = True

View file

@@ -2388,7 +2388,8 @@ struct SimpleBits {
 
 struct EvenSignHelper {
-#ifdef HAVE_FANCY_SIMD
+#if defined HAVE_FANCY_SIMD
+    // #pragma message("Using AVX512VPOPCNTDQ in even sign helper")
     union sbits_t {
         __m128i vec;
         __mmask32 mask[4];