ktransformers (mirror of https://github.com/kvcache-ai/ktransformers.git)

commit 2d67016d14 ("wjh-change")
parent 7c94df4bcf
4 changed files with 74 additions and 26 deletions
@@ -0,0 +1,56 @@
+- match:
+    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching both name and class
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV2MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False  # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
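Each rule in this new optimize-rules file pairs a `match` filter (a regex on the module's qualified name and/or its class) with a `replace` spec naming the substitute class and its construction kwargs; note the negative lookahead in the Linear rule, which leaves `self_attn.kv_b_proj` untouched, presumably so the optimized MLA attention can consume it directly. A minimal sketch of how such rules could drive module injection follows; `load_class`, `inject`, and the wrapper-style constructor convention are illustrative assumptions, not ktransformers' actual injector, and `recursive: False` handling is omitted for brevity.

import importlib
import re

import torch.nn as nn

def load_class(path: str):
    # "pkg.mod.Cls" -> Cls via a standard importlib lookup
    module_name, _, cls_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), cls_name)

def inject(root: nn.Module, rules: list, prefix: str = "model") -> None:
    for name, child in list(root.named_children()):
        qualified = f"{prefix}.{name}"
        child_cls = f"{type(child).__module__}.{type(child).__qualname__}"
        for rule in rules:
            m = rule["match"]
            if "name" in m and not re.match(m["name"], qualified):
                continue
            if "class" in m and m["class"] != child_cls:
                continue
            # Assumed convention: the replacement wraps the original module
            # and receives the rule's kwargs (devices, op backends) verbatim.
            new_cls = load_class(rule["replace"]["class"])
            setattr(root, name, new_cls(child, **rule["replace"].get("kwargs", {})))
            break
        else:
            inject(child, rules, qualified)  # recurse only where nothing matched

Passing the kwargs through like this is what lets the rules above choose device placement and kernel per phase, e.g. KLinearMarlin for generation but KLinearTorch for prefill.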
@@ -9,19 +9,19 @@ class ArgumentParser:
     def parse_args(self):
         parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
         parser.add_argument("--host", type=str, default=self.cfg.server_ip)
-        parser.add_argument("--port", type=int, default=self.cfg.server_port)
+        parser.add_argument("--port", type=int, default=8082)
         parser.add_argument("--ssl_keyfile", type=str)
         parser.add_argument("--ssl_certfile", type=str)
-        parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
-        parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
-        parser.add_argument("--model_dir", type=str, default=self.cfg.model_dir)
+        parser.add_argument("--web", type=bool, default=True)
+        parser.add_argument("--model_name", type=str, default='DeepSeek-V2-Lite-Chat')
+        parser.add_argument("--model_dir", type=str, default='/mnt/data/model/DeepSeek-V2-Lite-Chat')
         parser.add_argument(
             "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
         )
-        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
-        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
+        parser.add_argument("--gguf_path", type=str, default='/mnt/data/model/DeepSeek-V2-Lite-Chat-GGUF')
+        parser.add_argument("--optimize_config_path", default='/mnt/data/benchmark/ktransformers-dev/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml', type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
-        parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--type", type=str, default='ktransformers')

         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
@@ -69,7 +69,7 @@ class ArgumentParser:
         parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
         parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
         parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
-        parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
+        parser.add_argument("--cache_lens", type=int, default='32768')

         # log configs
         # log level: debug, info, warn, error, crit
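One caveat with the `type=bool` flags touched here: argparse calls `bool()` on the raw command-line string, and any non-empty string is truthy, so `--web False` still yields True. String defaults, by contrast, are run through `type`, which is why `default='32768'` with `type=int` still parses to the integer 32768. A common workaround, sketched here rather than taken from the repo:

import argparse

def str2bool(v: str) -> bool:
    # argparse-friendly boolean parser; plain bool("False") would be True
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--web", type=str2bool, default=True)
print(parser.parse_args(["--web", "false"]).web)  # prints: False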
@@ -164,6 +164,7 @@ class TransformersInterface(BackendInterfaceBase):
             if m["role"] == "system":
                 logger.warning(f'change {m["role"]} to user')
                 m["role"] = "user"
+
         new_messages = [messages[0]]
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
@@ -172,25 +173,12 @@ class TransformersInterface(BackendInterfaceBase):
             else:
                 new_messages.append(m)

-        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
-        #     logger.debug(f"last message: {new_messages[-1]}")
-        #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", add_generation_prompt=False).to(self.args.device)
-        # else:
-        #     input_ids = self.tokenizer.apply_chat_template(
-        #         new_messages, return_tensors="pt", add_generation_prompt=True
-        #     ).to(self.args.device)
-
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            x = self.generated_ids[:, :self.seq_length]
-            y = input_ids[:, :self.seq_length]
-            # We can only hope that the input_ids are the same
-            unequal_mask = torch.ne(x, y)
-            unequal_positions = torch.nonzero(unequal_mask)
-            num_unequal_elements = unequal_mask.sum().item()
-            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
-
-            input_ids = input_ids[:, self.seq_length:]
+            input_ids = self.tokenizer.encode(self.tokenizer.eos_token + self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", tokenize=False, add_generation_prompt=True), add_special_tokens=False, return_tensors="pt").to(self.args.device)
+        else:
+            input_ids = self.tokenizer.apply_chat_template(
+                new_messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.args.device)
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
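The rewritten branch changes how a follow-up request on the same thread is tokenized: instead of re-encoding the full history and diffing it against `generated_ids` (the deleted `torch.ne` bookkeeping), it encodes only the newest message, prefixed with the EOS token to close the previous assistant turn, so the new tokens line up with the KV cache already held for the thread. A standalone sketch of the idea, with a hypothetical helper name and an example model id:

from transformers import AutoTokenizer

def next_input_ids(tokenizer, new_messages, same_thread: bool):
    if same_thread:
        # Encode only the delta; the server's KV cache already covers
        # the earlier turns, so the history is not re-tokenized.
        text = tokenizer.eos_token + tokenizer.apply_chat_template(
            [new_messages[-1]], tokenize=False, add_generation_prompt=True
        )
        return tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    # Fresh thread: encode the whole conversation with a generation prompt.
    return tokenizer.apply_chat_template(
        new_messages, return_tensors="pt", add_generation_prompt=True
    )

tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True)
msgs = [{"role": "user", "content": "hello"}, {"role": "user", "content": "and again"}]
print(next_input_ids(tok, msgs, same_thread=True).shape)

This stays consistent only if every earlier turn went through the same chat template, which is presumably why the old equality check was dropped rather than kept alongside the new path.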
@@ -4,6 +4,10 @@ from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
+import sys
+
+project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+sys.path.insert(0, project_dir)
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
 from ktransformers.server.config.config import Config
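The added lines put the checkout's root directory on `sys.path` before the `ktransformers.server` imports resolve, so the server can run straight from a source tree; note they rely on `os` already being imported above this hunk. An equivalent formulation with `pathlib`, assuming the file sits three levels below the repository root:

import sys
from pathlib import Path

# parents[2] == dirname(dirname(dirname(__file__)))
project_dir = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(project_dir))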