mirror of https://github.com/kvcache-ai/ktransformers.git
[ADD] support multi-gpu qlen>1 q5_k
commit f5f79f5c0e (parent f293803156)
63 changed files with 3271 additions and 1285 deletions
ktransformers/local_chat.py — 13 changes (Normal file → Executable file)
@@ -31,18 +31,21 @@ import fire
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate
 from ktransformers.server.config.config import Config

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+    "MixtralForCausalLM": MixtralForCausalLM,
 }

 ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
 default_optimize_rules = {
     "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
+    "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
 }

 def local_chat(
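The table above maps each supported architecture to a default optimize-rule YAML. A minimal sketch of how such a table can be resolved for a loaded config; the `config` object and the error handling are illustrative, not part of this diff:

    # Illustrative only: resolve an optimize-rule YAML from the table above.
    # `config` is assumed to be a transformers PretrainedConfig with `architectures` set.
    arch = config.architectures[0]
    optimize_rule_path = default_optimize_rules.get(arch)
    if optimize_rule_path is None:
        raise ValueError(f"No default optimize rule for architecture: {arch}")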
@@ -50,7 +53,8 @@ def local_chat(
     optimize_rule_path: str = None,
     gguf_path: str = None,
     max_new_tokens: int = 1000,
-    cpu_infer: int = Config().cpu_infer
+    cpu_infer: int = Config().cpu_infer,
+    use_cuda_graph: bool = True,
 ):
     torch.set_grad_enabled(False)
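The new use_cuda_graph parameter is exposed through the fire CLI alongside the existing flags. A minimal sketch of calling the entry point directly; the model_path parameter and both paths are assumptions for illustration, not shown in this hunk:

    # Illustrative call; model_path and the GGUF path are placeholders, not taken from this diff.
    from ktransformers.local_chat import local_chat

    local_chat(
        model_path="/path/to/hf-config-dir",   # assumed parameter, not visible in this hunk
        gguf_path="/path/to/gguf-weights",     # placeholder
        max_new_tokens=256,
        use_cuda_graph=False,                  # new flag: disable CUDA-graph decoding if capture causes problems
    )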
@@ -64,6 +68,8 @@ def local_chat(
         print("using custom modeling_xxx.py.")
         if "Qwen2Moe" in config.architectures[0]:  # Qwen2Moe must use flash_attention_2 to avoid overflow.
             config._attn_implementation = "flash_attention_2"
+        if "Mixtral" in config.architectures[0]:
+            config._attn_implementation = "flash_attention_2"
         model = custom_models[config.architectures[0]](config)
     else:
         model = AutoModelForCausalLM.from_config(
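The added branch mirrors the existing Qwen2Moe special case: Mixtral configs are forced onto flash_attention_2 before the custom model class is constructed. A minimal sketch of the same override applied to a standalone config; the checkpoint name is a placeholder:

    from transformers import AutoConfig

    # Placeholder checkpoint; only the override pattern is taken from this hunk.
    config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", trust_remote_code=True)
    if "Mixtral" in config.architectures[0]:
        config._attn_implementation = "flash_attention_2"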
@@ -100,7 +106,6 @@ def local_chat(
     while True:
         content = input("Chat: ")
         # if content is num
         if content == "":
             content = "Please write a piece of quicksort code in C++."
@@ -109,7 +114,7 @@ def local_chat(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
         torch.set_default_dtype(torch.bfloat16)  # TODO: Remove this, replace dtype using config
-        generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)
+        generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)

 if __name__ == "__main__":
-    fire.Fire(local_chat)
+    fire.Fire(local_chat)
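use_cuda_graph is now threaded through to prefill_and_generate. The sketch below shows the general capture-and-replay pattern such a flag typically gates; it is illustrative only, not the ktransformers implementation, and decode_step with its missing cache handling is a simplified assumption:

    import torch

    def decode_step(model, input_ids):
        # Simplified single-step decode; a real decoder would also pass KV caches and positions.
        return model(input_ids).logits[:, -1, :]

    def build_decode_fn(model, example_ids, use_cuda_graph=True):
        if not use_cuda_graph:
            return lambda ids: decode_step(model, ids)

        static_ids = example_ids.clone()
        graph = torch.cuda.CUDAGraph()
        decode_step(model, static_ids)          # warm-up pass before capture
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):           # capture one decode step on static buffers
            static_logits = decode_step(model, static_ids)

        def replay(ids):
            static_ids.copy_(ids)               # refresh the captured input buffer in place
            graph.replay()                      # rerun the captured kernels
            return static_logits
        return replay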