diff --git a/.github/ISSUE_TEMPLATE/-bug-.yaml b/.github/ISSUE_TEMPLATE/-bug-.yaml
new file mode 100644
index 0000000..7c74c6e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/-bug-.yaml
@@ -0,0 +1,39 @@
+name: 🐞 Bug report
+description: Create a report to help us reproduce and fix the bug
+title: "[Bug] "
+labels: ['Bug']
+
+body:
+- type: checkboxes
+  attributes:
+    label: Checklist
+    options:
+    - label: 1. I have searched related issues but cannot get the expected help.
+    - label: 2. The bug has not been fixed in the latest version.
+    - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
+    - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
+    - label: 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.
+
+- type: textarea
+  attributes:
+    label: Describe the bug
+    description: A clear and concise description of what the bug is.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Reproduction
+    description: |
+      What command or script did you run? Which **model** are you using?
+    placeholder: |
+      A placeholder for the command.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Environment
+    description: |
+      Please provide necessary environment information here (e.g. OS/GPU/CPU). Otherwise the issue will be closed.
+    placeholder: Environment here.
+  validations:
+    required: true
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/-bug2-.yaml b/.github/ISSUE_TEMPLATE/-bug2-.yaml
new file mode 100644
index 0000000..85b2180
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/-bug2-.yaml
@@ -0,0 +1,39 @@
+name: 🐞 BUG报告
+description: 创建报告以帮助我们复现并修复BUG
+title: "[Bug] "
+labels: ['Bug']
+
+body:
+- type: checkboxes
+  attributes:
+    label: 检查清单
+    options:
+    - label: 1. 我已经搜索过相关问题,但未能获得预期的帮助
+    - label: 2. 该问题在最新版本中尚未修复
+    - label: 3. 请注意,如果您提交的BUG相关 issue 缺少对应环境信息和最小可复现示例,我们将难以复现和定位问题,降低获得反馈的可能性
+    - label: 4. 如果您提出的不是bug而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
+    - label: 5. 为方便社区交流,我将使用中文/英文或附上中文/英文翻译(如使用其他语言)。未附带翻译的非中文/英语内容可能会被关闭
+
+- type: textarea
+  attributes:
+    label: 问题描述
+    description: 清晰简洁地描述BUG是什么
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: 复现步骤
+    description: |
+      你运行了什么命令或脚本?使用的是哪个**模型**?
+    placeholder: |
+      在此处填写命令
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: 环境信息
+    description: |
+      请提供必要的环境信息(如操作系统/GPU/CPU),否则该 issue 将被关闭
+    placeholder: 在此处填写环境信息
+  validations:
+    required: true
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/-feature-.yaml b/.github/ISSUE_TEMPLATE/-feature-.yaml
new file mode 100644
index 0000000..4ef23c4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/-feature-.yaml
@@ -0,0 +1,23 @@
+name: 🚀 Feature request
+description: Suggest an idea for this project
+title: "[Feature] "
+
+body:
+- type: checkboxes
+  attributes:
+    label: Checklist
+    options:
+    - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
+    - label: 2. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-English/Chinese content without translation may be closed.
+- type: textarea
+  attributes:
+    label: Motivation
+    description: |
+      A clear and concise description of the motivation of the feature.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Related resources
+    description: |
+      If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/-feature2-.yaml b/.github/ISSUE_TEMPLATE/-feature2-.yaml
new file mode 100644
index 0000000..571af4a
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/-feature2-.yaml
@@ -0,0 +1,23 @@
+name: 🚀 新功能请求
+description: 为项目提出新功能建议
+title: "[Feature] "
+
+body:
+- type: checkboxes
+  attributes:
+    label: 检查清单
+    options:
+    - label: 1. 如果您提出的不是新功能而是问题,请在讨论区发起讨论 https://github.com/kvcache-ai/ktransformers/discussions。否则该 issue 将被关闭
+    - label: 2. 为方便社区交流,我将使用中文/英文或附上英文/中文翻译(如使用其他语言)。未附带翻译的非英文/中文内容可能会被关闭
+- type: textarea
+  attributes:
+    label: 需求背景
+    description: |
+      清晰简洁地描述该功能的背景需求
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: 相关资源
+    description: |
+      如果有官方代码实现或第三方实现,请在此提供相关信息,这将非常有帮助
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt
index ecce9b7..22623a5 100644
--- a/ktransformers/ktransformers_ext/CMakeLists.txt
+++ b/ktransformers/ktransformers_ext/CMakeLists.txt
@@ -209,6 +209,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
 if (WIN32)
     include_directories("$ENV{CUDA_PATH}/include")
+    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
 elseif (UNIX)
     if (KTRANSFORMERS_USE_CUDA)
         find_package(CUDA REQUIRED)
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp
index c59cb94..4190c03 100644
--- a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp
@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <cstdint>
+
 void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
                                 float *attn_lse, int batch_size,
                                 Backend *backend) {
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp
index eadf90f..4de217f 100644
--- a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp
@@ -9,6 +9,9 @@
 **/
 
 #include "kvcache.h"
+
+#include <cstdint>
+
 void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
     // Timer start
     auto start = std::chrono::high_resolution_clock::now();
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp
index 998f1b0..0104905 100644
--- a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp
@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <cstdint>
+
 void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                    int block_idx, Backend *backend) {
     // Timer start
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp
index f1d6f7d..c57d475 100644
--- a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp
@@ -10,6 +10,8 @@
 
 #include "kvcache.h"
 
+#include <cstdint>
+
 std::string ggml_type_to_string(ggml_type type) {
     switch (type) {
     case GGML_TYPE_F32:
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 7cbac7c..c6c9c2e 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -110,15 +110,15 @@ def local_chat(
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
 
     try:
-        model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
-        gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True
-        )
-        model.generation_config = gen_config
+        model.generation_config = GenerationConfig.from_pretrained(model_path)
+    except Exception as e:
+        print(f"Generation config can't be created automatically, using defaults. Message: {e}")
+        gen_config = GenerationConfig(
+            temperature=0.6,
+            top_p=0.95,
+            do_sample=True
+        )
+        model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
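Note on the local_chat.py hunk above: loading sampling defaults from the checkpoint directory fails when no generation_config.json ships with the weights, so the code now logs the reason and falls back to explicit defaults. A minimal standalone sketch of the same pattern (the helper name and the example defaults are illustrative, not part of this diff):

    from transformers import GenerationConfig

    def load_generation_config(model_path: str) -> GenerationConfig:
        # Prefer the sampling defaults shipped with the checkpoint (generation_config.json).
        try:
            return GenerationConfig.from_pretrained(model_path)
        except Exception as e:
            # Fall back to explicit values when the file is missing or unreadable.
            print(f"Can't load generation config from {model_path}, using defaults. Message: {e}")
            return GenerationConfig(temperature=0.6, top_p=0.95, do_sample=True)
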
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 35c8093..25b1359 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+
+        assert False
+
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index f8ea3ce..2bec5cc 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
-        
+
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
 
-    q_len = 10
+    kv_len = 2069
+    q_len = 1
 
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
 
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
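Aside on the flashinfer_wrapper.py self-test above: it now models a single-token decode step (q_len = 1) against a 2069-token paged KV cache, and the reference path uses causal masking; with only one query placed at the end of the sequence, the causal mask excludes nothing, so the comparison is still full attention over the cache. A self-contained plain-PyTorch sketch of that reference computation (this is not the repo's attention_ref helper; the head count, dimensions, and 192 ** (-0.5) scale mirror the test):

    import torch

    def ref_decode_attention(q, k, v, sm_scale):
        # q: [num_heads, qk_dim]; k: [kv_len, qk_dim]; v: [kv_len, v_dim] (KV shared across heads).
        scores = torch.einsum("hd,td->ht", q.float(), k.float()) * sm_scale  # [num_heads, kv_len]
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("ht,te->he", probs, v.float())                   # [num_heads, v_dim]

    num_heads, kv_len = 128, 2069
    q = torch.randn(num_heads, 576)   # 512 latent ("nope") dims + 64 rope dims
    k = torch.randn(kv_len, 576)
    v = torch.randn(kv_len, 512)
    out = ref_decode_attention(q, k, v, sm_scale=192 ** (-0.5))
    print(out.shape)                  # torch.Size([128, 512])
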
diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py
index 4e91279..dd7185f 100644
--- a/ktransformers/server/api/openai/endpoints/chat.py
+++ b/ktransformers/server/api/openai/endpoints/chat.py
@@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id):
+            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request,inner())
     else:
         comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id):
+        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
             comp.append_token(token)
         return comp
diff --git a/ktransformers/server/api/openai/legacy/completions.py b/ktransformers/server/api/openai/legacy/completions.py
index be85a29..fe250f4 100644
--- a/ktransformers/server/api/openai/legacy/completions.py
+++ b/ktransformers/server/api/openai/legacy/completions.py
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id):
+            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id):
+        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
             comp.append_token(token)
         return comp
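With the two endpoint hunks above, the OpenAI-compatible routes forward the request's temperature and top_p into interface.inference, so they can be set per request. A hedged client-side example (host, port, route, and model name are assumptions about a local deployment, not values taken from this diff; the backend treats a missing or zero temperature as "use the model's generation_config default", as shown further down in transformers.py):

    import requests

    resp = requests.post(
        "http://localhost:10002/v1/chat/completions",   # adjust to your server's address
        json={
            "model": "DeepSeek-R1",                      # placeholder model name
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
            "temperature": 0.6,
            "top_p": 0.95,
        },
    )
    print(resp.json())
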
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index ce9cb71..4201c20 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -14,9 +14,9 @@
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
-
 warm_uped = False
 
 class KTransformersThreadContext(TransformersThreadContext):
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except Exception:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
+
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
                 " belong to current model):"
             )
         optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
-        
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
+
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
@@ -128,7 +129,7 @@ class KTransformersInterface(TransformersInterface):
 
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
@@ -206,7 +207,7 @@
 
         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -215,7 +216,7 @@
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
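The interface hunks above thread the per-request sampling parameters from inference() through prefill() into prepare_logits_wrapper(); the defaulting itself is implemented in transformers.py below. A tiny pure-function sketch of that precedence rule (function and argument names are illustrative, not from this diff):

    from typing import Optional

    def resolve_sampling(temperature: Optional[float], top_p: Optional[float],
                         default_temperature: float, default_top_p: float):
        # A request that omits temperature (or sends 0) falls back to the model's
        # generation_config value; an omitted top_p falls back the same way.
        if temperature is None or temperature == 0:
            temperature = default_temperature
        if top_p is None:
            top_p = default_top_p
        return temperature, top_p

    assert resolve_sampling(None, None, 0.6, 0.95) == (0.6, 0.95)
    assert resolve_sampling(0, 0.8, 0.6, 0.95) == (0.6, 0.8)
    assert resolve_sampling(1.0, None, 0.6, 0.95) == (1.0, 0.95)
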
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index e6d444e..9bc3117 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -202,20 +202,23 @@
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
-        self.generation_config = generation_config
         try: # transformers==4.43
             self.logits_warper = (
-                self.model._get_logits_warper(generation_config,device=device)
+                self.model._get_logits_warper(generation_config, device=device)
             )
         except:
             self.logits_warper = (
@@ -255,7 +258,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
 
@@ -323,7 +326,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -365,7 +368,7 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -392,7 +395,7 @@ class TransformersInterface(BackendInterfaceBase):
                 print(think, end="",flush=True)
                 yield think
 
-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py
index 19505e2..e5d8f95 100644
--- a/ktransformers/server/schemas/endpoints/chat.py
+++ b/ktransformers/server/schemas/endpoints/chat.py
@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model : str
     stream : bool = False
-    
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
diff --git a/ktransformers/server/schemas/legacy/completions.py b/ktransformers/server/schemas/legacy/completions.py
index 874e556..ea936ea 100644
--- a/ktransformers/server/schemas/legacy/completions.py
+++ b/ktransformers/server/schemas/legacy/completions.py
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
 
     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 86051b5..84ada15 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -27,6 +27,7 @@ import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
 import ctypes
+import math
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -230,7 +231,7 @@ class GGUFLoader:
             shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
             ggml_type = read_value(f, DATA_TYPES["uint32"])
             bad_offset = read_value(f, DATA_TYPES["uint64"])
-            n_elems = int(np.prod(shape))
+            n_elems = int(math.prod(shape))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             np_dims = tuple(reversed(shape))
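On the custom_gguf.py hunk above: math.prod keeps the element count as an arbitrary-precision Python int, while np.prod reduces with a NumPy integer dtype and can silently wrap around on very large tensors when the default integer is 32-bit (as with older NumPy builds on Windows). A small illustration (the shape is a plausible stacked-MoE example, not taken from this diff; the 32-bit dtype is forced so the overflow reproduces everywhere):

    import math
    import numpy as np

    shape = [256, 2048, 7168]                   # ~3.76e9 elements, more than 2**31 - 1
    print(math.prod(shape))                     # 3758096384  (exact Python int)
    print(int(np.prod(shape, dtype=np.int32)))  # -536870912  (wrapped around)
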
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 87bbd2b..f030257 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -170,7 +170,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     past_key_values.cur_idx=cache_position
     start_time = time.time()
 
-    inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
     if mode == "long_context":
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
     else:
@@ -183,8 +182,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
     )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
     generation_config, model_kwargs = model._prepare_generation_config(
-        None, max_length=max_new_tokens,
-        do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+        None, do_sample=True
+        # change this to modify generate config
+        #top_k=5, top_p=0.85, temperature=0.1
     )
     try: # transformers==4.43
         logits_warper = (