wjh-change

anyanqilin 2024-10-29 03:35:05 +00:00 committed by liam
parent 7c94df4bcf
commit 2d67016d14
4 changed files with 74 additions and 26 deletions

ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (new file)

@@ -0,0 +1,56 @@
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching both name and class
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE  # MLP module with a custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
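
Each rule above pairs a match (a module-name regex, a class, or both) with a replacement class and constructor kwargs, and the first matching rule wins for a given module. Below is a minimal, self-contained sketch of how such match/replace injection can work over a PyTorch module tree. It is illustrative only, not ktransformers' actual injector: `inject`, `build`, and `MyQuantLinear` are made-up names, and classes are matched as Python types rather than dotted strings.

import re
import torch.nn as nn

def inject(model: nn.Module, rules: list[dict]) -> None:
    # Walk a snapshot of the named submodules and apply the first rule that matches.
    for name, module in list(model.named_modules()):
        if not name:  # skip the root module itself
            continue
        for rule in rules:
            match = rule["match"]
            if "name" in match and not re.match(match["name"], name):
                continue
            if "class" in match and not isinstance(module, match["class"]):
                continue
            # Rebind the attribute on the parent so the swap is visible to forward().
            parent_name, _, attr = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr, rule["build"](module))
            break  # first matching rule wins

# Hypothetical usage mirroring the KTransformersLinear rule above:
# rules = [{"match": {"name": r"^model\.layers\..*", "class": nn.Linear},
#           "build": lambda old: MyQuantLinear.from_linear(old)}]

The `recursive: False` flag matters for rules like the experts one: once `KTransformersExperts` has replaced the whole experts subtree, its original children must not be matched and replaced again.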

ktransformers/server/args.py

@@ -9,19 +9,19 @@ class ArgumentParser:
     def parse_args(self):
         parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
         parser.add_argument("--host", type=str, default=self.cfg.server_ip)
-        parser.add_argument("--port", type=int, default=self.cfg.server_port)
+        parser.add_argument("--port", type=int, default=8082)
         parser.add_argument("--ssl_keyfile", type=str)
         parser.add_argument("--ssl_certfile", type=str)
-        parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
-        parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
-        parser.add_argument("--model_dir", type=str, default=self.cfg.model_dir)
+        parser.add_argument("--web", type=bool, default=True)
+        parser.add_argument("--model_name", type=str, default='DeepSeek-V2-Lite-Chat')
+        parser.add_argument("--model_dir", type=str, default='/mnt/data/model/DeepSeek-V2-Lite-Chat')
         parser.add_argument(
             "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
         )
-        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
-        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
+        parser.add_argument("--gguf_path", type=str, default='/mnt/data/model/DeepSeek-V2-Lite-Chat-GGUF')
+        parser.add_argument("--optimize_config_path", default='/mnt/data/benchmark/ktransformers-dev/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml', type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
-        parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--type", type=str, default='ktransformers')
         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
@@ -69,7 +69,7 @@ class ArgumentParser:
         parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings)
         parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia)
         parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
-        parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
+        parser.add_argument("--cache_lens", type=int, default=32768)
         # log configs
         # log level: debug, info, warn, error, crit
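
One caveat in the flags above: `--web` is declared with `type=bool`, and argparse simply applies `bool()` to the raw command-line string, so any non-empty value, including the literal string "False", parses as True. A quick self-contained demonstration:

import argparse

parser = argparse.ArgumentParser()
# bool() on a non-empty string is always True, so type=bool cannot express "off".
parser.add_argument("--web", type=bool, default=True)

print(parser.parse_args(["--web", "False"]).web)  # True, not False
print(parser.parse_args(["--web", ""]).web)       # False (only the empty string is falsy)

The usual fix is `action="store_true"`/`action="store_false"`, or `argparse.BooleanOptionalAction` on Python 3.9+.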

ktransformers/server/backend/interfaces/transformers.py

@@ -164,6 +164,7 @@ class TransformersInterface(BackendInterfaceBase):
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
@@ -172,25 +173,12 @@ class TransformersInterface(BackendInterfaceBase):
             else:
                 new_messages.append(m)
-        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
-        #     logger.debug(f"last message: {new_messages[-1]}")
-        #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", add_generation_prompt=False).to(self.args.device)
-        # else:
-        #     input_ids = self.tokenizer.apply_chat_template(
-        #         new_messages, return_tensors="pt", add_generation_prompt=True
-        #     ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors='pt', add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            x = self.generated_ids[:, :self.seq_length]
-            y = input_ids[:, :self.seq_length]
-            # We can only hope that the input_ids are the same
-            unequal_mask = torch.ne(x, y)
-            unequal_positions = torch.nonzero(unequal_mask)
-            num_unequal_elements = unequal_mask.sum().item()
-            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
-            input_ids = input_ids[:, self.seq_length:]
+            input_ids = self.tokenizer.encode(self.tokenizer.eos_token + self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt", tokenize=False, add_generation_prompt=True), add_special_tokens=False, return_tensors="pt").to(self.args.device)
         else:
             input_ids = self.tokenizer.apply_chat_template(
                 new_messages, return_tensors="pt", add_generation_prompt=True
             ).to(self.args.device)
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
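
The net effect of this hunk: when a request continues the same thread (`last_request_id == thread_id`), the server no longer re-tokenizes the full history and diffs it against `generated_ids`; it tokenizes only the newest message, rendered through the chat template and prefixed with the EOS token so it splices onto the cached sequence. A rough sketch of the idea, assuming any Hugging Face tokenizer that ships a chat template (whether an EOS prefix is the right separator depends on that template):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True)

def delta_input_ids(new_message: dict):
    # Render only the newest message as templated text ...
    text = tok.apply_chat_template([new_message], tokenize=False, add_generation_prompt=True)
    # ... then tokenize it with a leading EOS so it can be appended to the
    # cached ids without re-encoding (or re-prefilling) the earlier turns.
    return tok.encode(tok.eos_token + text, add_special_tokens=False, return_tensors="pt")

ids = delta_input_ids({"role": "user", "content": "and one more question"})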

ktransformers/server/main.py

@@ -4,6 +4,10 @@ from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
+import sys
+project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+sys.path.insert(0, project_dir)
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
 from ktransformers.server.config.config import Config
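
The inserted `sys.path` shim makes the checkout importable without a pip install by walking three directory levels up from this file (it relies on `os` being imported earlier in the module). Assuming the layout implied by the paths elsewhere in this commit:

import os

main_py = "/mnt/data/benchmark/ktransformers-dev/ktransformers/server/main.py"
print(os.path.dirname(main_py))                                     # .../ktransformers-dev/ktransformers/server
print(os.path.dirname(os.path.dirname(main_py)))                    # .../ktransformers-dev/ktransformers
print(os.path.dirname(os.path.dirname(os.path.dirname(main_py))))   # .../ktransformers-dev (repo root)

With the repo root on `sys.path`, the absolute imports `from ktransformers.server...` below resolve to this checkout rather than to any installed copy of the package.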