mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
Initial commit
This commit is contained in:
commit
18c42e67df
247 changed files with 53775 additions and 0 deletions
115
ktransformers/local_chat.py
Normal file
115
ktransformers/local_chat.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2024 Shaoyuan Chen
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
project_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.path.insert(0, project_dir)
|
||||
import torch
|
||||
import logging
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
TextStreamer,
|
||||
)
|
||||
import json
|
||||
import fire
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||
from ktransformers.util.utils import prefill_and_generate
|
||||
from ktransformers.server.config.config import Config
|
||||
|
||||
custom_models = {
|
||||
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
||||
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
|
||||
}
|
||||
|
||||
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
|
||||
default_optimize_rules ={
|
||||
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
|
||||
}
|
||||
|
||||
def local_chat(
|
||||
model_path: str,
|
||||
optimize_rule_path: str = None,
|
||||
gguf_path: str = None,
|
||||
max_new_tokens: int = 1000,
|
||||
cpu_infer: int = Config().cpu_infer
|
||||
):
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
Config().cpu_infer = cpu_infer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
torch.set_default_dtype(config.torch_dtype)
|
||||
|
||||
with torch.device("meta"):
|
||||
if config.architectures[0] in custom_models:
|
||||
print("using custom modeling_xxx.py.")
|
||||
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
model = custom_models[config.architectures[0]](config)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=True, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
if optimize_rule_path is None:
|
||||
if config.architectures[0] in default_optimize_rules:
|
||||
print("using default_optimize_rule for", config.architectures[0])
|
||||
optimize_rule_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_rule_path = input(
|
||||
"please input the path of your rule file(yaml file containing optimize rules):"
|
||||
)
|
||||
|
||||
if gguf_path is None:
|
||||
gguf_path = input(
|
||||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
|
||||
|
||||
model.generation_config = GenerationConfig.from_pretrained(model_path)
|
||||
if model.generation_config.pad_token_id is None:
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
||||
model.eval()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
system = platform.system()
|
||||
if (system == u'Windows'):
|
||||
os.system('cls')
|
||||
else:
|
||||
os.system('clear')
|
||||
|
||||
while True:
|
||||
content = input("Chat: ")
|
||||
# if content is num
|
||||
if content == "":
|
||||
content = "Please write a piece of quicksort code in C++."
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
input_tensor = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
|
||||
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(local_chat)
|
Loading…
Add table
Add a link
Reference in a new issue