mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
✨: refactor local_chat and fix message slice bug in server
This commit is contained in:
parent
43fc7f44a6
commit
dd1d8667f3
13 changed files with 549 additions and 405 deletions
|
@ -1,33 +1,23 @@
|
|||
"""
|
||||
Description :
|
||||
Description :
|
||||
Author : Boxin Zhang, Azure-Tang
|
||||
Version : 0.1.0
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from ktransformers.server.args import ArgumentParser
|
||||
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.path.insert(0, project_dir)
|
||||
import torch
|
||||
import logging
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
TextStreamer,
|
||||
)
|
||||
import json
|
||||
import fire
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||
from ktransformers.util.utils import prefill_and_generate
|
||||
from ktransformers.server.config.config import Config
|
||||
|
||||
custom_models = {
|
||||
|
@ -37,9 +27,7 @@ custom_models = {
|
|||
"MixtralForCausalLM": MixtralForCausalLM,
|
||||
}
|
||||
|
||||
ktransformer_rules_dir = (
|
||||
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||
)
|
||||
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||
default_optimize_rules = {
|
||||
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
|
||||
|
@ -48,75 +36,28 @@ default_optimize_rules = {
|
|||
}
|
||||
|
||||
|
||||
def local_chat(
|
||||
model_path: str | None = None,
|
||||
optimize_rule_path: str = None,
|
||||
gguf_path: str | None = None,
|
||||
max_new_tokens: int = 1000,
|
||||
cpu_infer: int = Config().cpu_infer,
|
||||
use_cuda_graph: bool = True,
|
||||
prompt_file : str | None = None,
|
||||
mode: str = "normal",
|
||||
):
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
Config().cpu_infer = cpu_infer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
if mode == 'long_context':
|
||||
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
|
||||
torch.set_default_dtype(torch.float16)
|
||||
def local_chat():
|
||||
config = Config()
|
||||
arg_parser = ArgumentParser(config)
|
||||
# 初始化消息
|
||||
arg_parser.parse_args()
|
||||
if config.backend_type == "transformers":
|
||||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
|
||||
elif config.backend_type == "exllamav2":
|
||||
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
|
||||
elif config.backend_type == "ktransformers":
|
||||
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
|
||||
else:
|
||||
torch.set_default_dtype(config.torch_dtype)
|
||||
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] in custom_models:
|
||||
print("using custom modeling_xxx.py.")
|
||||
if (
|
||||
"Qwen2Moe" in config.architectures[0]
|
||||
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
if "Llama" in config.architectures[0]:
|
||||
config._attn_implementation = "eager"
|
||||
if "Mixtral" in config.architectures[0]:
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
model = custom_models[config.architectures[0]](config)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=True, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
if optimize_rule_path is None:
|
||||
if config.architectures[0] in default_optimize_rules:
|
||||
print("using default_optimize_rule for", config.architectures[0])
|
||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_rule_path = input(
|
||||
"please input the path of your rule file(yaml file containing optimize rules):"
|
||||
)
|
||||
|
||||
if gguf_path is None:
|
||||
gguf_path = input(
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
|
||||
|
||||
model.generation_config = GenerationConfig.from_pretrained(model_path)
|
||||
if model.generation_config.pad_token_id is None:
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
||||
model.eval()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
raise NotImplementedError(f"{config.backend_type} not implemented")
|
||||
interface = BackendInterface(config)
|
||||
|
||||
system = platform.system()
|
||||
if system == "Windows":
|
||||
os.system("cls")
|
||||
else:
|
||||
os.system("clear")
|
||||
|
||||
# add a history chat content
|
||||
his_content = []
|
||||
while True:
|
||||
content = input("Chat: ")
|
||||
if content.startswith('"""'): # prefix """
|
||||
|
@ -132,28 +73,27 @@ def local_chat(
|
|||
break
|
||||
else:
|
||||
content += line + "\n"
|
||||
|
||||
if content == "":
|
||||
if prompt_file != None:
|
||||
content = open(prompt_file, "r").read()
|
||||
else:
|
||||
if config.prompt_file == None or config.prompt_file == "":
|
||||
content = "Please write a piece of quicksort code in C++."
|
||||
else:
|
||||
content = open(config.prompt_file, "r").read()
|
||||
elif os.path.isfile(content):
|
||||
content = open(content, "r").read()
|
||||
messages = [{"role": "user", "content": content}]
|
||||
input_tensor = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
if mode == 'long_context':
|
||||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
||||
torch.set_default_dtype(
|
||||
torch.bfloat16
|
||||
) # TODO: Remove this, replace dtype using config
|
||||
generated = prefill_and_generate(
|
||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
|
||||
)
|
||||
messages = his_content + [{"role": "user", "content": content}]
|
||||
|
||||
async def async_inference(messages):
|
||||
generated = ""
|
||||
async for token in interface.inference(messages, "local_chat"):
|
||||
generated += token
|
||||
return generated
|
||||
|
||||
generated = asyncio.run(async_inference(messages))
|
||||
his_content += [
|
||||
{"role": "user", "content": content},
|
||||
{"role": "assitant", "content": generated},
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(local_chat)
|
||||
local_chat()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue