mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00

commit 80e0536fb0: Merge branch 'main' of https://github.com/KMSorSMS/ktransformers into main

20 changed files with 231 additions and 53 deletions
.github/ISSUE_TEMPLATE/-bug-.yaml (vendored, new file, 39 lines)
name: 🐞 Bug report
description: Create a report to help us reproduce and fix the bug
title: "[Bug] "
labels: ['Bug']

body:
  - type: checkboxes
    attributes:
      label: Checklist
      options:
        - label: 1. I have searched related issues but cannot get the expected help.
        - label: 2. The bug has not been fixed in the latest version.
        - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
        - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
        - label: 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.

  - type: textarea
    attributes:
      label: Describe the bug
      description: A clear and concise description of what the bug is.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Reproduction
      description: |
        What command or script did you run? Which **model** are you using?
      placeholder: |
        A placeholder for the command.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Environment
      description: |
        Please provide necessary environment information here (e.g. OS/GPU/CPU). Otherwise the issue will be closed.
      placeholder: Environment here.
    validations:
      required: true
.github/ISSUE_TEMPLATE/-bug2-.yaml (vendored, new file, 39 lines)
name: 🐞 BUG报告
description: 创建报告以帮助我们复现并修复BUG
title: "[Bug] "
labels: ['Bug']

body:
  - type: checkboxes
    attributes:
      label: 检查清单
      options:
        - label: 1. 我已经搜索过相关问题,但未能获得预期的帮助
        - label: 2. 该问题在最新版本中尚未修复
        - label: 3. 请注意,如果您提交的BUG相关 issue 缺少对应环境信息和最小可复现示例,我们将难以复现和定位问题,降低获得反馈的可能性
        - label: 4. 如果您提出的不是bug而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
        - label: 5. 为方便社区交流,我将使用中文/英文或附上中文/英文翻译(如使用其他语言)。未附带翻译的非中文/英语内容可能会被关闭

  - type: textarea
    attributes:
      label: 问题描述
      description: 清晰简洁地描述BUG是什么
    validations:
      required: true
  - type: textarea
    attributes:
      label: 复现步骤
      description: |
        你运行了什么命令或脚本?使用的是哪个**模型**?
      placeholder: |
        在此处填写命令
    validations:
      required: true
  - type: textarea
    attributes:
      label: 环境信息
      description: |
        请提供必要的环境信息(如操作系统/GPU/CPU),否则该 issue 将被关闭
      placeholder: 在此处填写环境信息
    validations:
      required: true
.github/ISSUE_TEMPLATE/-feature-.yaml (vendored, new file, 23 lines)
name: 🚀 Feature request
description: Suggest an idea for this project
title: "[Feature] "

body:
  - type: checkboxes
    attributes:
      label: Checklist
      options:
        - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
        - label: 2. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-English/Chinese content without translation may be closed.
  - type: textarea
    attributes:
      label: Motivation
      description: |
        A clear and concise description of the motivation of the feature.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Related resources
      description: |
        If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
.github/ISSUE_TEMPLATE/-feature2-.yaml (vendored, new file, 23 lines)
name: 🚀 新功能请求
description: 为项目提出新功能建议
title: "[Feature] "

body:
  - type: checkboxes
    attributes:
      label: 检查清单
      options:
        - label: 1. 如果您提出的不是新功能而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
        - label: 2. 为方便社区交流,我将使用中文/英文或附上英文/中文翻译(如使用其他语言)。未附带翻译的非英文/中文内容可能会被关闭
  - type: textarea
    attributes:
      label: 需求背景
      description: |
        清晰简洁地描述该功能的背景需求
    validations:
      required: true
  - type: textarea
    attributes:
      label: 相关资源
      description: |
        如果有官方代码实现或第三方实现,请在此提供相关信息,这将非常有帮助
@@ -209,6 +209,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
 if (WIN32)
     include_directories("$ENV{CUDA_PATH}/include")
+    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
 elseif (UNIX)
     if (KTRANSFORMERS_USE_CUDA)
         find_package(CUDA REQUIRED)
@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <chrono>
+
 void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
                                 float *attn_lse, int batch_size,
                                 Backend *backend) {

@@ -9,6 +9,9 @@
 **/
 
 #include "kvcache.h"
 
+#include <chrono>
+
 void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
   // Timer start
   auto start = std::chrono::high_resolution_clock::now();

@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <chrono>
+
 void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                    int block_idx, Backend *backend) {
   // Timer start

@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <chrono>
+
 std::string ggml_type_to_string(ggml_type type) {
   switch (type) {
   case GGML_TYPE_F32:
@@ -111,11 +111,11 @@ def local_chat(
 
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
         gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
+            temperature=0.6,
+            top_p=0.95,
             do_sample=True
         )
         model.generation_config = gen_config
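Note: catching Exception instead of a bare except keeps KeyboardInterrupt/SystemExit propagating and lets the failure be reported before falling back. A minimal standalone sketch of the same pattern, assuming only the transformers library; model_path is a placeholder and the defaults mirror the values in the diff:

# Standalone sketch of the fallback above; model_path is a placeholder.
from transformers import GenerationConfig

def load_generation_config(model_path):
    try:
        # Works when the checkpoint ships a generation_config.json.
        return GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        return GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)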
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,

@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+
+        assert False
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
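Note: the new forward_linux_flashinfer only dispatches for now. Prompts at or below chunck_size (or the non-absorbed path) go to the renamed forward_linux_flashinfer_chunk, while the long absorbed-prefill branch still ends in assert False. For illustration only, splitting a long prefill into fixed-size pieces could look like the sketch below; this is not the repository's implementation, and chunk_size merely stands in for the module's chunck_size attribute:

# Illustration only: fixed-size chunking of a long prefill.
def iter_prefill_chunks(q_len, chunk_size):
    for start in range(0, q_len, chunk_size):
        yield start, min(start + chunk_size, q_len)

print(list(iter_prefill_chunks(q_len=10000, chunk_size=4096)))
# [(0, 4096), (4096, 8192), (8192, 10000)]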
@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
 
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")

@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,

@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
 
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
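Note: the test constants above change together. With page_size = 64, a kv_len of 2069 cannot fit in max_pages = 1, which is why the page count grows to 128. A quick check of the sizing, using the numbers from the diff:

# Page count needed for the new test configuration.
import math

page_size = 64
kv_len = 2069
max_pages = 128

pages_needed = math.ceil(kv_len / page_size)
print(pages_needed)               # 33
print(pages_needed <= max_pages)  # True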
@@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id):
+            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request,inner())
     else:
         comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id):
+        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
            comp.append_token(token)
        return comp

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id):
+            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}

@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id):
+        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
            comp.append_token(token)
        return comp
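Note: together with the schema fields added in the ChatCompletionCreate/CompletionCreate hunks further down, these routes let a client override sampling per request. A hedged usage sketch follows; the host, port, route and model name are assumptions, not taken from this diff — only the request fields mirror the new schema:

# Assumed endpoint, port and model name; the fields mirror the diff.
import requests

resp = requests.post(
    "http://localhost:10002/v1/chat/completions",
    json={
        "model": "DeepSeek-V2",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
        "temperature": 0.6,  # forwarded as create.temperature
        "top_p": 0.95,       # forwarded as create.top_p
    },
    timeout=600,
)
print(resp.json())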
@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 warm_uped = False
 
 class KTransformersThreadContext(TransformersThreadContext):
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.temperature,
+                do_sample=True
+            )
+
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
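Note: in the fallback above, top_p is initialized from args.temperature, which reads like a typo for args.top_p (args.top_p is used elsewhere in this file). A hedged sketch of the presumed intent, keeping the argument names from the diff:

# Presumed intent of the fallback; `args.top_p` is an assumption --
# the committed code passes args.temperature to top_p.
from transformers import GenerationConfig

def default_generation_config(args):
    return GenerationConfig(
        max_length=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,      # assumed; the diff uses args.temperature here
        do_sample=True,
    )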
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
             " belong to current model):"
         )
         optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
@@ -128,7 +129,7 @@ class KTransformersInterface(TransformersInterface):
 
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")

@@ -206,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
 
         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

@@ -215,7 +216,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                yield v
@@ -202,17 +202,20 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
-        self.generation_config = generation_config
         try: # transformers==4.43
             self.logits_warper = (
                 self.model._get_logits_warper(generation_config, device=device)
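Note: the resolution order added above is: a per-request value wins, and a missing value (or temperature == 0) falls back to the model's generation_config defaults. A minimal standalone sketch of that logic, with illustrative function and parameter names:

# Standalone sketch of the fallback logic in prepare_logits_wrapper.
from typing import Optional

def resolve_sampling(temperature: Optional[float], top_p: Optional[float],
                     default_temperature: float, default_top_p: float):
    if temperature is None or temperature == 0:
        temperature = default_temperature
    if top_p is None:
        top_p = default_top_p
    return temperature, top_p

print(resolve_sampling(None, 0.95, default_temperature=0.6, default_top_p=1.0))  # (0.6, 0.95)
print(resolve_sampling(0.0, None, default_temperature=0.6, default_top_p=1.0))   # (0.6, 1.0)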
@@ -255,7 +258,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")

@@ -323,7 +326,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

@@ -365,7 +368,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -392,7 +395,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think
 
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
@@ -25,6 +25,8 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model : str
     stream : bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
 
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
    stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
 
     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
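Note: because the new fields default to None, requests that omit them keep working and the backend can fall back to the model's generation_config. A hedged sketch of the parsing behaviour; the class name is illustrative (not the project's class) and the str | List[str] annotation needs Python 3.10+:

# Illustrative pydantic model mirroring the fields added above.
from typing import List, Optional
from pydantic import BaseModel

class CompletionCreateSketch(BaseModel):
    model: str
    prompt: str | List[str]
    stream: bool = False
    temperature: Optional[float] = None
    top_p: Optional[float] = None

print(CompletionCreateSketch(model="m", prompt="hi").temperature)       # None
print(CompletionCreateSketch(model="m", prompt="hi", top_p=0.9).top_p)  # 0.9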
@@ -27,6 +27,7 @@ import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
 import ctypes
+import math
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -230,7 +231,7 @@ class GGUFLoader:
             shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
             ggml_type = read_value(f, DATA_TYPES["uint32"])
             bad_offset = read_value(f, DATA_TYPES["uint64"])
-            n_elems = int(np.prod(shape))
+            n_elems = int(math.prod(shape))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             np_dims = tuple(reversed(shape))
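Note: math.prod keeps the element count as an arbitrary-precision Python int, while np.prod computes in a fixed-width integer dtype and can overflow silently for very large shapes. A quick demonstration with a hypothetical, deliberately huge shape:

import math
import numpy as np

shape = [8, 2**20, 2**20, 2**20]   # hypothetical, deliberately huge
print(math.prod(shape))            # 9223372036854775808 -- exact
print(int(np.prod(shape)))         # wraps around in fixed-width integer arithmetic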
@@ -170,7 +170,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         past_key_values.cur_idx=cache_position
         start_time = time.time()
 
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         if mode == "long_context":
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:
@@ -183,8 +182,9 @@
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
     generation_config, model_kwargs = model._prepare_generation_config(
-        None, max_length=max_new_tokens,
-        do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+        None, do_sample=True
+        # change this to modify generate config
+        #top_k=5, top_p=0.85, temperature=0.1
     )
     try: # transformers==4.43
         logits_warper = (