liam 2025-03-01 00:12:21 +08:00
commit 80e0536fb0
20 changed files with 231 additions and 53 deletions

.github/ISSUE_TEMPLATE/-bug-.yaml (vendored, new file, +39)

@@ -0,0 +1,39 @@
name: 🐞 Bug report
description: Create a report to help us reproduce and fix the bug
title: "[Bug] "
labels: ['Bug']
body:
  - type: checkboxes
    attributes:
      label: Checklist
      options:
        - label: 1. I have searched related issues but cannot get the expected help.
        - label: 2. The bug has not been fixed in the latest version.
        - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
        - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
        - label: 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.
  - type: textarea
    attributes:
      label: Describe the bug
      description: A clear and concise description of what the bug is.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Reproduction
      description: |
        What command or script did you run? Which **model** are you using?
      placeholder: |
        A placeholder for the command.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Environment
      description: |
        Please provide the necessary environment information here (e.g. OS/GPU/CPU). Otherwise the issue will be closed.
      placeholder: Environment here.
    validations:
      required: true

.github/ISSUE_TEMPLATE/-bug2-.yaml (vendored, new file, +39)

@@ -0,0 +1,39 @@
name: 🐞 BUG报告
description: 创建报告以帮助我们复现并修复BUG
title: "[Bug] "
labels: ['Bug']
body:
  - type: checkboxes
    attributes:
      label: 检查清单
      options:
        - label: 1. 我已经搜索过相关问题,但未能获得预期的帮助
        - label: 2. 该问题在最新版本中尚未修复
        - label: 3. 请注意,如果您提交的BUG相关 issue 缺少对应环境信息和最小可复现示例,我们将难以复现和定位问题,降低获得反馈的可能性
        - label: 4. 如果您提出的不是bug而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
        - label: 5. 为方便社区交流,我将使用中文/英文或附上中文/英文翻译(如使用其他语言)。未附带翻译的非中文/英文内容可能会被关闭
  - type: textarea
    attributes:
      label: 问题描述
      description: 清晰简洁地描述BUG是什么
    validations:
      required: true
  - type: textarea
    attributes:
      label: 复现步骤
      description: |
        你运行了什么命令或脚本?使用的是哪个**模型**?
      placeholder: |
        在此处填写命令
    validations:
      required: true
  - type: textarea
    attributes:
      label: 环境信息
      description: |
        请提供必要的环境信息(如操作系统/GPU/CPU),否则该 issue 将被关闭
      placeholder: 在此处填写环境信息
    validations:
      required: true

.github/ISSUE_TEMPLATE/-feature-.yaml (vendored, new file, +23)

@@ -0,0 +1,23 @@
name: 🚀 Feature request
description: Suggest an idea for this project
title: "[Feature] "
body:
  - type: checkboxes
    attributes:
      label: Checklist
      options:
        - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
        - label: 2. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-English/Chinese content without translation may be closed.
  - type: textarea
    attributes:
      label: Motivation
      description: |
        A clear and concise description of the motivation for the feature.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Related resources
      description: |
        If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.

.github/ISSUE_TEMPLATE/-feature2-.yaml (vendored, new file, +23)

@@ -0,0 +1,23 @@
name: 🚀 新功能请求
description: 为项目提出新功能建议
title: "[Feature] "
body:
  - type: checkboxes
    attributes:
      label: 检查清单
      options:
        - label: 1. 如果您提出的不是新功能而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
        - label: 2. 为方便社区交流,我将使用中文/英文或附上英文/中文翻译(如使用其他语言)。未附带翻译的非英文/中文内容可能会被关闭
  - type: textarea
    attributes:
      label: 需求背景
      description: |
        清晰简洁地描述该功能的背景需求
    validations:
      required: true
  - type: textarea
    attributes:
      label: 相关资源
      description: |
        如果有官方代码实现或第三方实现,请在此提供相关信息,这将非常有帮助

View file

@@ -209,6 +209,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
 if (WIN32)
     include_directories("$ENV{CUDA_PATH}/include")
+    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
 elseif (UNIX)
     if (KTRANSFORMERS_USE_CUDA)
         find_package(CUDA REQUIRED)

View file

@@ -10,6 +10,8 @@
 #include "kvcache.h"
+#include <chrono>
 void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
                                 float *attn_lse, int batch_size,
                                 Backend *backend) {

View file

@@ -9,6 +9,9 @@
 **/
 #include "kvcache.h"
+#include <chrono>
 void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
     // Timer start
     auto start = std::chrono::high_resolution_clock::now();

View file

@@ -10,6 +10,8 @@
 #include "kvcache.h"
+#include <chrono>
 void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                    int block_idx, Backend *backend) {
     // Timer start

View file

@@ -10,6 +10,8 @@
 #include "kvcache.h"
+#include <chrono>
 std::string ggml_type_to_string(ggml_type type) {
     switch (type) {
     case GGML_TYPE_F32:

View file

@@ -110,15 +110,15 @@ def local_chat(
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
-        gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
+        gen_config = GenerationConfig(
+            temperature=0.6,
+            top_p=0.95,
             do_sample=True
         )
         model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
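
The hunk above replaces the silent bare "except:" with an explicit message and drops the hard-coded max_length=128 in favour of sampling defaults (temperature 0.6, top_p 0.95). A minimal sketch of the same fallback pattern in isolation, assuming a hypothetical model_path that may or may not ship a generation_config.json:

    from transformers import GenerationConfig

    def load_generation_config(model_path: str) -> GenerationConfig:
        try:
            # Uses generation_config.json from the model directory when present.
            return GenerationConfig.from_pretrained(model_path)
        except Exception as e:
            print(f"generation config can't auto create, make default. Message: {e}")
            # Defaults mirroring the diff above.
            return GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)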

View file

@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        assert False
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
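
The renamed forward_linux_flashinfer_chunk keeps the original chunked path, while the new forward_linux_flashinfer only delegates to it when the query is short (q_len <= self.chunck_size) or absorbed prefill is disabled, and currently asserts otherwise. A framework-free sketch of that dispatch shape, with illustrative names (ChunkedAttention, forward_chunk) that are not part of the module:

    import torch

    class ChunkedAttention(torch.nn.Module):
        def __init__(self, chunk_size: int = 1000, absorb_for_prefill: bool = False):
            super().__init__()
            self.chunk_size = chunk_size
            self.absorb_for_prefill = absorb_for_prefill

        def forward_chunk(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Stands in for the original chunked flashinfer path.
            return hidden_states

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            bsz, q_len, _ = hidden_states.size()
            # Short queries (or absorb disabled) take the existing chunked path.
            if q_len <= self.chunk_size or not self.absorb_for_prefill:
                return self.forward_chunk(hidden_states)
            # The long absorbed-prefill path is not shown in this hunk.
            raise NotImplementedError("absorbed prefill path omitted")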

View file

@@ -139,6 +139,11 @@ class MLAWrapper():
         )
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
@@ -244,15 +250,15 @@
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
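
The test now distinguishes the query length (q_len = 1, a single decode step) from the cache length (kv_len = 2069) and enlarges the page pool (max_pages = 128, page_size = 64); the reference attention is computed causally over the first kv_len cache entries. A small sketch of the paged-KV bookkeeping these numbers imply; the index layout here is an illustration, not flashinfer's exact buffer format:

    import math
    import torch

    page_size = 64
    kv_len = 2069
    q_len = 1

    used_pages = math.ceil(kv_len / page_size)                # 33 pages hold 2069 cached tokens
    kv_indices = torch.arange(used_pages, dtype=torch.int32)  # page ids 0..32
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32)    # per-request cache length
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32)   # one request, one query token
    print(used_pages, kv_indices.shape, kv_len_arr, qo_indptr)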

View file

@@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id):
+            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request,inner())
     else:
         comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id):
+        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
             comp.append_token(token)
         return comp
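
With create.temperature and create.top_p now forwarded to interface.inference, sampling can be set per request. A hedged client-side sketch; the host, port, and /v1/chat/completions path are assumptions about the local server setup, and the model name is a placeholder:

    import requests

    payload = {
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
        "temperature": 0.6,   # optional per-request override
        "top_p": 0.95,        # optional per-request override
    }
    resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload)
    print(resp.json())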

View file

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id):
+            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id):
+        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
            comp.append_token(token)
        return comp

View file

@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 warm_uped = False
 class KTransformersThreadContext(TransformersThreadContext):
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
             " belong to current model):"
         )
         optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
@@ -128,7 +129,7 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
@@ -206,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -215,7 +216,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v

View file

@@ -202,20 +202,23 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
+        self.generation_config = generation_config
         try: # transformers==4.43
             self.logits_warper = (
-                self.model._get_logits_warper(generation_config,device=device)
+                self.model._get_logits_warper(generation_config, device=device)
             )
         except:
             self.logits_warper = (
@@ -255,7 +258,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -323,7 +326,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -365,7 +368,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -392,7 +395,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
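
The precedence added in prepare_logits_wrapper is: use the per-request value when given and non-zero, otherwise fall back to model.generation_config. A standalone sketch of that rule with illustrative names:

    from typing import Optional

    def resolve_sampling(temperature: Optional[float], top_p: Optional[float],
                         default_temperature: float, default_top_p: float):
        # Per-request values win; None (or temperature == 0) falls back to the model's config.
        if temperature is None or temperature == 0:
            temperature = default_temperature
        if top_p is None:
            top_p = default_top_p
        return temperature, top_p

    print(resolve_sampling(None, 0.9, 0.6, 0.95))   # (0.6, 0.9)
    print(resolve_sampling(0.2, None, 0.6, 0.95))   # (0.2, 0.95)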

View file

@@ -25,6 +25,8 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model : str
     stream : bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

View file

@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
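
Both request schemas gain optional temperature and top_p fields defaulting to None, so clients that omit them keep the previous behaviour. A minimal pydantic sketch of that parsing behaviour; the class name is a stand-in, not the server's real schema:

    from typing import List, Optional
    from pydantic import BaseModel

    class CompletionCreateSketch(BaseModel):
        model: str
        prompt: str | List[str]
        stream: bool = False
        temperature: Optional[float] = None
        top_p: Optional[float] = None

    # Omitted fields parse to None; explicit values are carried through.
    print(CompletionCreateSketch(model="m", prompt="hi"))
    print(CompletionCreateSketch(model="m", prompt="hi", temperature=0.6, top_p=0.95))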

View file

@@ -27,6 +27,7 @@ import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
 import ctypes
+import math
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -230,7 +231,7 @@ class GGUFLoader:
             shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
             ggml_type = read_value(f, DATA_TYPES["uint32"])
             bad_offset = read_value(f, DATA_TYPES["uint64"])
-            n_elems = int(np.prod(shape))
+            n_elems = int(math.prod(shape))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             np_dims = tuple(reversed(shape))
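
Switching from np.prod(shape) to math.prod(shape) keeps the element count in an exact Python integer; np.prod reduces in a fixed-width NumPy integer and, on platforms or NumPy versions where the default integer is 32-bit (for example Windows with NumPy 1.x), it can silently wrap for very large tensors. An illustration that forces a 32-bit accumulator to show the failure mode; exact behaviour depends on the NumPy version:

    import math
    import numpy as np

    shape = [100000, 50000]                      # 5e9 elements, for illustration only
    print(math.prod(shape))                      # 5000000000, exact Python int
    print(int(np.prod(shape, dtype=np.int32)))   # wraps around (typically 705032704)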

View file

@@ -170,7 +170,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         past_key_values.cur_idx=cache_position
         start_time = time.time()
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         if mode == "long_context":
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:
@@ -183,8 +182,9 @@
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
-            None, max_length=max_new_tokens,
-            do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+            None, do_sample=True
+            # change this to modify generate config
+            #top_k=5, top_p=0.85, temperature=0.1
         )
         try: # transformers==4.43
             logits_warper = (