Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 20:49:55 +00:00)

Merge branch 'kvcache-ai:main' into main
Commit e5694f91c0
17 changed files with 356 additions and 163 deletions
@@ -8,4 +8,4 @@ Version : 1.0.0
 LastEditors : chenxl
 LastEditTime : 2025-02-15 03:53:02
 '''
-__version__ = "0.2.3"
+__version__ = "0.2.3.post1"
@@ -175,6 +175,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         list(APPEND ARCH_FLAGS -mavx512bw)
         list(APPEND ARCH_FLAGS -mavx512dq)
         list(APPEND ARCH_FLAGS -mavx512vnni)
+        list(APPEND ARCH_FLAGS -mavx512vpopcntdq)
     endif()
     if (LLAMA_AVX512_BF16)
         list(APPEND ARCH_FLAGS -mavx512bf16)
@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
-    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 
 logger = logging.getLogger("attention")
 
@@ -1,9 +1,11 @@
 '''
 Description : flashinfer MLA wrapper
 Author : Boxin Zhang
-Version : 0.2.2
+Version : 0.2.3
 '''
 import torch
+import os
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 
 flashinfer_enabled = False
 
@@ -17,7 +19,7 @@ except ImportError:
 
 import math
 
-def attention_ref(
+def attention_ref_torch(
     batch_size,
     q: torch.Tensor,
     k: torch.Tensor,
@@ -122,7 +124,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
 
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -139,11 +141,6 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
         wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
         wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
         wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 
+def checksame():
+    flashinfer_folder = "./flashinfer_output"
+    flashinfer_folder = "./kv_cache_flashinfer"
+    triton_folder = "./triton_output"
+    triton_folder = "./kv_cache_triton"
+
+    max_layer_id = 1
+    max_forward_id = 2
+
+    for forward_id in range(0, 19):
+        print("forward_id", forward_id)
+        for layer_id in range(max_layer_id):
+            print(layer_id)
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+            file_name = f"layer_{layer_id}.pt"
+
+            flashinfer_path = os.path.join(flashinfer_folder, file_name)
+            triton_path = os.path.join(triton_folder, file_name)
+
+            if not os.path.exists(triton_path):
+                print(f"{file_name} not exist in {triton_folder}")
+                continue
+            if not os.path.exists(flashinfer_path):
+                print(f"{file_name} not exist in {flashinfer_folder}")
+                continue
+
+            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
+            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
+            try:
+                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
+            except AssertionError as e:
+                print(e)
+
 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
 
+    #checksame()
+    #exit(0)
+
     max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
     page_size = 64
     num_heads = 128
 
     # warm-up
     kv_len = 4023
     q_len = 1
-    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
-    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
-    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
 
     wrapper = MLAWrapperSingleton.get_instance(
@@ -241,51 +276,105 @@ if __name__ == "__main__":
         torch.bfloat16,
     )
 
-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
     print(attn_output.shape)
 
     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph):
         attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
 
     kv_len = 6789
     kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
         None,
         None,
         kv_len_arr,
         128,
         512,
         64,
         page_size,
         192 ** (-0.5),
         torch.bfloat16,
         torch.bfloat16,
     )
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    # warm-up finished
 
-    for forward_id in range(0, 1):
-        print("forward_id", forward_id)
-        for layer_id in range(1):
-            print(layer_id)
+    flashinfer_folder = "./kv_cache_flashinfer"
+    forward_id = 17
+    layer_id = 0
+    file_name = f"layer_{layer_id}.pt"
+    kv_cache_path = os.path.join(flashinfer_folder, file_name)
+    flashinfer_folder = "./flashinfer_output"
 
-            q_len = 1
+    q_len = 1
+    kv_len = 126
+    file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
+    q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
+    file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+    q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
+    q = torch.cat([q_nope, q_pe], dim=-1)
+    kv_cache = torch.load(kv_cache_path).to(device="cuda")
+    pages, page_size, _, head_dim = kv_cache.shape
+    kv_cache = kv_cache.view(pages, page_size, head_dim)
+    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
 
-            graph.replay()
-
-            k = (
-                torch.cat([ckv, k_pe], dim=-1)
-                .view(-1, 1, 512 + 64)
-                .repeat_interleave(num_heads, dim=1)
-            )
-            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-
-            print(k[:kv_len].shape)
-            print(v[:kv_len].shape)
-
-            attn_ref, lse_ref = attention_ref(
-                max_batch_size,
-                torch.cat([q_nope, q_pe], dim=-1),
-                k[:kv_len],
-                v[:kv_len],
-                True,
-                192 ** (-0.5)
-            )
-            print(attn_ref.shape)
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+    wrapper.plan(
+        None,
+        None,
+        None,
+        kv_len_arr,
+        128,
+        512,
+        64,
+        page_size,
+        192 ** (-0.5),
+        torch.bfloat16,
+        torch.bfloat16,
+    )
+
+    q_nope_buf.copy_(q_nope)
+    q_pe_buf.copy_(q_pe)
+    kv_buf[:pages].copy_(kv_cache)
+
+    torch.cuda.synchronize()
+    graph.replay()
+    torch.cuda.synchronize()
+
+    # ref_torch
+    k = (
+        torch.cat([ckv, k_pe], dim=-1)
+        .view(-1, 1, 512 + 64)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+    attn_ref, lse_ref = attention_ref_torch(
+        max_batch_size,
+        q,
+        k[:kv_len],
+        v[:kv_len],
+        False,
+        192 ** (-0.5)
+    )
+    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
+
+    # ref_triton
+    attn_logits = torch.empty(
+        (
+            max_batch_size,
+            num_heads,
+            4, #num_kv_splits # follow vLLM, fix it TODO
+            512 + 1,
+        ),
+        dtype=torch.float32,
+        device = "cuda"
+    )
+
+    triton_ref = torch.zeros_like(q_nope)
+    page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
+    ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
+    ckv = ckv.view(pages, page_size, 1, 512)
+    decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
+                                 page_table,
+                                 kv_len_arr, attn_logits,
+                                 4, #num_kv_splits # follow vLLM, fix it TODO
+                                 192 ** (-0.5),
+                                 page_size)
+
+    torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
+
+    #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+    #ktrans_output = torch.load(file_name)
+    #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
+    print("test past")
-
-            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
-            print("test past")
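For orientation, the reference paths above all build keys by concatenating the 512-dim compressed KV (ckv) with the 64-dim rotary part (k_pe), reuse ckv alone as values, and scale scores by 192 ** (-0.5), which matches qk_nope_head_dim + qk_rope_head_dim = 128 + 64 used elsewhere in this diff. The sketch below is a minimal, self-contained PyTorch restatement of that construction; it is illustrative only and not the repository's attention_ref_torch (shapes and the helper name are assumptions).

    import torch

    def mla_attention_reference(q, ckv, k_pe, sm_scale):
        # q: [q_len, num_heads, 512 + 64]; ckv: [kv_len, 512]; k_pe: [kv_len, 64]
        kv_len = ckv.shape[0]
        num_heads = q.shape[1]
        # keys = [ckv | k_pe] broadcast over heads, values = ckv (as in the test above)
        k = torch.cat([ckv, k_pe], dim=-1).view(kv_len, 1, 512 + 64).repeat_interleave(num_heads, dim=1)
        v = ckv.view(kv_len, 1, 512).repeat_interleave(num_heads, dim=1)
        scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale  # [heads, q_len, kv_len]
        attn = torch.softmax(scores, dim=-1)
        out = torch.einsum("hqk,khd->qhd", attn, v.float())                     # [q_len, heads, 512]
        return out.to(q.dtype)

    q_len, kv_len, num_heads = 1, 126, 128
    q = torch.randn(q_len, num_heads, 576)
    ckv = torch.randn(kv_len, 512)
    k_pe = torch.randn(kv_len, 64)
    print(mla_attention_reference(q, ckv, k_pe, 192 ** (-0.5)).shape)  # torch.Size([1, 128, 512])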
@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
 
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -61,14 +63,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
-                d = OllamaGenerationStreamResponse(
-                    model=config.model_name,
-                    created_at=str(datetime.now()),
-                    response=token,
-                    done=False
-                )
-                yield d.model_dump_json() + '\n'
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        response=token,
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
             d = OllamaGenerationStreamResponse(
                 model=config.model_name,
                 created_at=str(datetime.now()),
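The change above switches interface.inference() from yielding bare token strings to yielding (token, finish_reason) tuples, followed by a single RawUsage object once generation ends. A minimal, self-contained sketch of that contract (the RawUsage fields mirror the schema added later in this diff; the fake generator and its values are illustrative only):

    import asyncio
    from pydantic import BaseModel

    class RawUsage(BaseModel):  # mirrors the schema added in this commit
        tokenize_time: float
        prefill_time: float
        decode_time: float
        prefill_count: int
        decode_count: int

    async def fake_inference(prompt: str):
        # Stand-in for interface.inference(): token tuples first, usage last.
        for tok in ["Hello", ",", " world"]:
            yield tok, None
        yield "", "stop"
        yield RawUsage(tokenize_time=0.01, prefill_time=0.2, decode_time=0.5,
                       prefill_count=12, decode_count=4)

    async def consume():
        async for res in fake_inference("hi"):
            if isinstance(res, RawUsage):
                print("usage:", res.model_dump())
            else:
                token, finish_reason = res
                print("token:", repr(token), "finish_reason:", finish_reason)

    asyncio.run(consume())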
@@ -142,14 +148,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             eval_count = 0  # number of generated tokens
             tokens = []
 
-            async for token in interface.inference(prompt, id):
-                d = OllamaChatCompletionStreamResponse(
-                    model=config.model_name,
-                    created_at=str(datetime.now()),
-                    message={"role": "assistant", "content": token},
-                    done=False
-                )
-                yield d.model_dump_json() + '\n'
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaChatCompletionStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        message={"role": "assistant", "content": token},
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
             # compute performance stats
             end_time = time()
             total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
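A small sketch of how the profiler counters carried by RawUsage can be folded into Ollama-style nanosecond fields. Only total_duration and eval_count appear in the handler above; the remaining field names follow the Ollama API document linked earlier, and the exact mapping here is an assumption, not the repository's code.

    def to_ollama_stats(raw_usage, start_time: float, end_time: float) -> dict:
        ns = 1_000_000_000
        return {
            "total_duration": int((end_time - start_time) * ns),   # wall clock, as above
            "prompt_eval_count": raw_usage.prefill_count,          # assumed mapping
            "prompt_eval_duration": int(raw_usage.prefill_time * ns),
            "eval_count": raw_usage.decode_count,
            "eval_duration": int(raw_usage.decode_time * ns),
        }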
@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
+
 
 router = APIRouter()
 
 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
         assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+
+                    yield chunk
+
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
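A sketch of how a client could consume the new stream with the openai SDK: token chunks carry a populated choices list, and the final chunk added by this commit has empty choices and a usage object. The base_url and port are taken from the local default used by the eval scripts later in this diff, and the model name and api_key value are assumptions; adjust to your deployment.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="your-configured-key")

    stream = client.chat.completions.create(
        model="DeepSeek-V3",  # assumed model name
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:                 # token chunks
            delta = chunk.choices[0].delta
            if delta.content:
                print(delta.content, end="", flush=True)
        elif chunk.usage:                 # final usage-only chunk
            print("\nprompt:", chunk.usage.prompt_tokens,
                  "completion:", chunk.usage.completion_tokens)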
@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
 
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data:{json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data:{json.dumps(d)}\n\n"
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp
@@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject
 from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
 from ktransformers.server.schemas.assistants.runs import RunObject
 from ktransformers.server.schemas.assistants.threads import ThreadObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.schemas.base import ObjectID, Order
 from ktransformers.server.utils.multi_timer import Profiler
 
@@ -142,12 +143,16 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-        async for token in self.interface.inference(local_messages,self.thread.id):
-            if self.run.status == RunObject.Status.cancelling:
-                logger.warn(f'Run {self.run.id} cancelling')
-                break
-            yield reply_message.append_message_delta(token)
-            response_str_count+=1
+        async for res in self.interface.inference(local_messages,self.thread.id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                if self.run.status == RunObject.Status.cancelling:
+                    logger.warn(f'Run {self.run.id} cancelling')
+                    break
+                yield reply_message.append_message_delta(token)
+                response_str_count+=1
 
         if self.run.status == RunObject.Status.cancelling:
             yield self.run.stream_response_with_event(RunObject.Status.cancelled)
@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 warm_uped = False
 
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+        # return this inference raw usage
+        yield RawUsage(
+            tokenize_time = self.profiler.get_timer_sec('tokenize'),
+            prefill_time = self.profiler.get_timer_sec('prefill'),
+            decode_time = self.profiler.get_timer_sec('decode'),
+            prefill_count = self.profiler.get_counter('prefill'),
+            decode_count = self.profiler.get_counter('decode'),
+        )
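The RawUsage yielded here exposes the profiler's timers and counters to whoever consumes the generator. A short illustrative helper for turning it into throughput numbers; the field names come from this diff, the function itself is not part of the repository.

    def summarize_usage(u) -> str:
        # u is a RawUsage instance as yielded above
        prefill_tps = u.prefill_count / u.prefill_time if u.prefill_time > 0 else float("nan")
        decode_tps = u.decode_count / u.decode_time if u.decode_time > 0 else float("nan")
        return (f"prefill: {u.prefill_count} tok in {u.prefill_time:.2f}s ({prefill_tps:.1f} tok/s), "
                f"decode: {u.decode_count} tok in {u.decode_time:.2f}s ({decode_tps:.1f} tok/s)")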
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -344,14 +344,21 @@ class TransformersInterface(BackendInterfaceBase):
                 MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                                              num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                              head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
-                                             sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
 
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
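The else: added to the decode loop above is Python's for-else: it runs only when the loop finishes without a break, i.e. when generation stops because max_new_tokens is exhausted rather than because an EOS token triggered the break. A tiny standalone illustration of how that maps onto the two finish reasons:

    def finish_reason_demo(tokens, eos, max_new_tokens):
        for tok in tokens[:max_new_tokens]:
            if tok == eos:
                print("finish_reason = stop")    # mirrors: yield "", "stop"
                break
        else:
            print("finish_reason = length")      # for's else: no break happened

    finish_reason_demo(["a", "b", "<eos>"], "<eos>", 10)  # prints stop
    finish_reason_demo(["a", "b", "c"], "<eos>", 2)       # prints length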
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
@@ -1,10 +1,15 @@
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum
 
 from pydantic import BaseModel
 
 from ktransformers.server.schemas.base import Object
 
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
+
 
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None
 
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
-
-class ChatCompletionBase(Object):
-    created:int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
-
-    def to_stream_reply(self):
-        return f"data: {self.model_dump_json()}\n\n"
+
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
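A sketch of how the new models above can be used together: one token-bearing chunk and one final usage-only chunk, serialized the same way the endpoint streams them. It assumes ChatCompletionChunk and RawUsage are importable from the schema module above; ids, model name, and token values are illustrative.

    from time import time
    from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
    from openai.types.completion_usage import CompletionUsage

    raw = RawUsage(tokenize_time=0.01, prefill_time=0.2, decode_time=0.5,
                   prefill_count=12, decode_count=4)

    token_chunk = ChatCompletionChunk(
        id="chat-1", created=int(time()), model="local-model",
        object="chat.completion.chunk",
        choices=[Choice(index=0,
                        delta=ChoiceDelta(content="Hello", role=None, tool_calls=None),
                        finish_reason=None, logprobs=None)],
    )
    usage_chunk = ChatCompletionChunk(
        id="chat-1", created=int(time()), model="local-model",
        object="chat.completion.chunk", choices=[],
        usage=CompletionUsage(prompt_tokens=raw.prefill_count,
                              completion_tokens=raw.decode_count,
                              total_tokens=raw.prefill_count + raw.decode_count),
    )
    print(token_chunk.model_dump_json())   # streamed while tokens are produced
    print(usage_chunk.model_dump_json())   # streamed once, at the end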
@@ -78,13 +78,15 @@ def run_eval_api(
     format_tabs: bool = False,
     auth_token: str = None,
     problem_file: str = None,
-    append: bool = False
+    append: bool = False,
+    skip: int = 0
 ):
 
     data = load_data(problem_file)
     pbar = tqdm.tqdm(total=len(data) * 1)
 
+    pbar.update(skip)
     for i in range(len(data)):
+        i = i+skip
         data_item = data[i]
         question = data_item['Problem']
         # Start the timer for this evaluation
@@ -97,6 +99,7 @@ def run_eval_api(
         score = get_score(completion, answer)
         elapsed_time = time.time() - start_time
         result = {
+            "index": i,
             "question_id": data_item["ID"],
             "answer": answer,
             "prediction": completion,
@@ -114,9 +117,9 @@ def run_eval_api(
         pbar.update(1)
 
 
-def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
 
 
 if __name__ == "__main__":
@@ -128,6 +131,7 @@ if __name__ == "__main__":
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip some tasks")
     args = parser.parse_args()
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)
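The new --skip flag lets an interrupted evaluation resume: the progress bar is advanced by the number of already-written results and the loop index is offset past them. Note that the loop above keeps range(len(data)) while adding skip to the index, so with a non-zero skip the last iterations run past the end of the data; the illustrative sketch below (hypothetical data, same tqdm usage) bounds the range instead.

    import tqdm

    data = [f"problem-{i}" for i in range(10)]
    skip = 4                      # e.g. 4 results already present in the output file

    pbar = tqdm.tqdm(total=len(data))
    pbar.update(skip)             # account for work done in a previous run
    for i in range(len(data) - skip):
        i = i + skip              # shift into the unprocessed tail, as run_eval_api does
        item = data[i]
        # ... evaluate item and append its result to the jsonl file ...
        pbar.update(1)
    pbar.close()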
@@ -39,7 +39,8 @@ def run_eval_api(
     format_tabs: bool = False,
     auth_token: str = None,
     problem_file: str = None,
-    append: bool = False
+    append: bool = False,
+    skip: int = 0
 ):
     if(problem_file is None):
         problems = read_problems()
@@ -47,8 +48,14 @@ def run_eval_api(
         problems = read_problems(problem_file)
     samples = []
     pbar = tqdm.tqdm(total=len(problems) * 1)
+    pbar.update(skip)
     try:
         for task_id in problems:
+            # skip some tasks
+            if skip > 0:
+                skip -= 1
+                continue
+
             if format_tabs:
                 prompt = problems[task_id]["prompt"].replace("    ", "\t")
             else:
@@ -67,23 +74,26 @@ def run_eval_api(
             if not append:
                 write_jsonl(out_path, samples,append=append)
     except Exception as e:
-        write_jsonl(out_path, samples,append=append)
+        if not append:
+            write_jsonl(out_path, samples,append=append)
         print(f"Error: {e}")
 
-def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
+def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
+    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
-    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
-    parser.add_argument("--out_path", type=str, default="results/api/eval.jsonl", help="Output Path")
+    parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
    parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
     parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
     parser.add_argument("--problem_file", type=str, default=None, help="Evalset File")
     parser.add_argument("--no_append", action="store_false", help="Append to existing file")
+    parser.add_argument("--skip", type=int, default=0, help="Skip first n problems")
     args = parser.parse_args()
     # api_url = "https://api.siliconflow.cn/v1/chat/completions"
-    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
+    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip)
@@ -8,7 +8,7 @@ def filter_code(completion: str) -> str:
     completion = completion.split('if __name__ == "__main__":')[0]
     if "# Example usage" in completion:
         completion = completion.split("# Example usage")[0]
-    return completion.split("\n\n")[0]
+    return completion
 
 
 def fix_indents(text: str) -> str:
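A small demonstration of what this change does: the old filter_code cut the completion at the first blank line, which could drop code after a docstring, while the new one keeps the whole body and still strips any trailing "__main__" block or "# Example usage" section. The sample string is illustrative.

    def filter_code_old(completion: str) -> str:
        completion = completion.split('if __name__ == "__main__":')[0]
        if "# Example usage" in completion:
            completion = completion.split("# Example usage")[0]
        return completion.split("\n\n")[0]

    def filter_code_new(completion: str) -> str:
        completion = completion.split('if __name__ == "__main__":')[0]
        if "# Example usage" in completion:
            completion = completion.split("# Example usage")[0]
        return completion

    sample = 'def add(a, b):\n    """Add two ints."""\n\n    return a + b\n\nif __name__ == "__main__":\n    print(add(1, 2))\n'
    print(repr(filter_code_old(sample)))  # truncated at the docstring's blank line
    print(repr(filter_code_new(sample)))  # keeps the full function body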
@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             if use_flashinfer_mla:
                 MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
                                              num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
+                                             model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
             global warm_uped
             if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
                 warm_uped = True
third_party/llamafile/iqk_mul_mat.inc (vendored): 3 lines changed
@@ -2388,7 +2388,8 @@ struct SimpleBits {
 
 
 struct EvenSignHelper {
-#ifdef HAVE_FANCY_SIMD
+#if defined HAVE_FANCY_SIMD
+    // #pragma message("Using AVX512VPOPCNTDQ in even sign helper")
     union sbits_t {
         __m128i vec;
         __mmask32 mask[4];