support npu

This commit is contained in:
djw 2025-07-21 12:26:14 +00:00
parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions

View file

@ -0,0 +1,257 @@
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os
import platform
import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import torch.distributed as dist
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils
from ktransformers.models.custom_cache import StaticCache
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
torch.npu.config.allow_internal_format = True
ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "npu/DeepSeek-V3-Chat.yaml",
}
torch.npu.set_compile_mode(jit_compile=False)
import sys, signal, faulthandler
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)
def local_chat(
model_path: str | None = None,
optimize_config_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = False,
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_size: int = utils._MAX_CHUNK_SIZE,
q4_gguf_path: str | None = None,
tp: int = 1,
):
utils.USE_NPU_GRAPH = use_cuda_graph
torch.npu.config.allow_internal_format = False
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
local_rank, world_size = setup_model_parallel(tp=tp)
if utils.CUR_DEVICE is None:
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if use_cuda_graph:
from ktransformers.util import npu_graph_runner
npu_graph_runner.LAYER_ID = config.num_hidden_layers
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_config_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0]) if local_rank == 0 else None
optimize_config_path = default_optimize_rules[config.architectures[0]]
print(f'{optimize_config_path=}') if local_rank == 0 else None
else:
optimize_config_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, q4_gguf_path=q4_gguf_path)
get_absort_weight(model, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
print(f"generation config can't auto create, make default. Message: {e}")
gen_config = GenerationConfig(
temperature=0.6,
top_p=0.95,
do_sample=True
)
model.generation_config = gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system()
if system == "Windows":
os.system("cls") if local_rank == 0 else None
else:
os.system("clear") if local_rank == 0 else None
print(f"{model=}") if local_rank == 0 else None
batch_size, seq_length = 1, 1024
device_map = model.gguf_loader.tensor_device_map
static_cache = StaticCache(
config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map,
dtype = model.dtype
)
chunk_size = int(chunk_size)
new_chunk_size = min(max(chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != chunk_size:
chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {chunk_size}.')
torch.distributed.barrier()
while True:
if local_rank == 0:
try:
content = input("Chat: ").strip()
except KeyboardInterrupt:
dist.barrier()
print('Exit all ranks with KeyboardInterrupt!')
sys.exit(0)
if content.startswith('"""'): # prefix """
# multi lines input
content = content[3:] + "\n"
while True:
line = input("")
if line.endswith('"""'):
# end multi lines input
line = line[:-3] # suffix """
if line:
content += line + "\n"
break
else:
content += line + "\n"
if content == "":
if prompt_file != None:
content = open(prompt_file, "r").read()
else:
continue
elif os.path.isfile(content):
f = open(content, "r")
content = f.readlines()
f.close()
else:
content = [f"{len(content)},{max_new_tokens},{content}"]
else:
content = [""]
for line in content:
content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
if world_size > 1:
content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE)
all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_sizes, content_size)
max_content_size = max([size.item() for size in all_content_sizes])
padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
padded_content_tensor[:len(content_tensor)] = content_tensor
all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_tensors, padded_content_tensor)
content_tensor = all_content_tensors[0][:all_content_sizes[0].item()]
line = bytes(content_tensor.cpu().numpy()).decode()
parts = line.split(",")
input_tokens = int(parts[0])
max_new_tokens = int(parts[1])
line = line[line.index(",", line.index(",") + 1) + 1:]
messages = [{"role": "user", "content": line}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
)
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim,
static_cache=static_cache
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
static_cache=static_cache
)
if __name__ == "__main__":
fire.Fire(local_chat)

View file

@ -16,6 +16,16 @@ try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
@ -37,6 +47,10 @@ class StaticCache(transformers.StaticCache):
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
Cache.__init__(self)
self.max_batch_size = max_batch_size
if use_torch_npu:
self.position = [0]
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
if config.architectures[0] == "DeepseekV3ForCausalLM":
@ -56,8 +70,18 @@ class StaticCache(transformers.StaticCache):
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
if use_torch_npu:
self.page_size = 128
self.page_size_tensor = torch.tensor(
self.page_size,
dtype=torch.int32,
).npu()
self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
else:
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
@ -71,9 +95,14 @@ class StaticCache(transformers.StaticCache):
target_device = device
if target_device not in self.page_table_map:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
if use_torch_npu:
page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
else:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
self.page_table_map[target_device] = page_table
self.page_table_list.append(self.page_table_map[target_device])
@ -140,11 +169,24 @@ class StaticCache(transformers.StaticCache):
self.past_tokens[layer_idx] += cache_position.size(0)
#print(cache_position)
if self.is_MLA:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
if use_torch_npu:
page_idx = cache_position // self.page_size_tensor
page_offset = cache_position % self.page_size_tensor
page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)
page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
page_idx = page_idx + page_idx_offset.unsqueeze(1)
combined = torch.cat([key_states, value_states], dim=-1)
combined = combined.contiguous()
else:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
return k_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states
@ -178,6 +220,9 @@ class StaticCache(transformers.StaticCache):
if self.value_cache[layer_idx] is not None:
self.value_cache[layer_idx].zero_()
self.past_tokens[layer_idx] = 0
if use_torch_npu:
self.position = [0]
def remove_suffix(self, start_pos):
for layer_idx in range(len(self.key_cache)):

View file

@ -27,8 +27,12 @@ try:
from flash_attn import flash_attn_func
except:
pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
try:
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
except:
Warning("triton not found, if you are using npu, ignore this.")
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:

View file

@ -1,5 +1,8 @@
import torch
import flashinfer
try:
import flashinfer
except:
Warning("flashinfer not found, if you are using npu, ignore this.")
import gc
try:
from flash_attn import flash_attn_with_kvcache

View file

@ -5,7 +5,11 @@ Version : 0.2.3
'''
import torch
import os
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
try:
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
except:
Warning("triton not found, if you are using npu, ignore this.")
flashinfer_enabled = False

View file

@ -14,7 +14,15 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
if not torch.xpu.is_available():
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
if not torch.xpu.is_available() and not use_torch_npu:
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader

View file

@ -16,6 +16,7 @@ from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
from ktransformers.util import utils
def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
for name, child in module._modules.items():
@ -114,7 +115,7 @@ def translate_model_config(model_config: PretrainedConfig):
return model_config
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0", q4_gguf_path=""):
with open(rule_file, 'r', encoding='utf-8') as f:
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
@ -123,15 +124,29 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
model_config = translate_model_config(model_config)
weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
if q4_gguf_path:
q4_gguf_loader = GGUFLoader(q4_gguf_path)
utils.Q4_GGUF_LODER = q4_gguf_loader
gguf_loader = GGUFLoader(gguf_path, getattr(model_config, "quantize", None))
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
else:
weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
del_meta(module)
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

View file

@ -0,0 +1,76 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "npu:0"
prefill_device: "npu:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "npu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "npu"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "npu"
prefill_device: "npu"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -22,6 +22,10 @@ class ArgumentParser:
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--q4_gguf_path", type=str, default=None)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)

View file

@ -8,6 +8,7 @@ class ConfigArgs(BaseModel):
model_dir: Optional[str] = Field(..., description="Path to model directory")
optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
tp: int = Field(None, description="tp size")
class Config:
protected_namespaces = ()

View file

@ -1,4 +1,19 @@
import torch
from torch import nn
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
from ktransformers.util.utils import get_device, get_all_used_cuda_device
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
import os
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
@ -19,6 +34,9 @@ from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage
warm_uped = False
class KTransformersThreadContext(TransformersThreadContext):
@ -26,8 +44,15 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
def __init__(self, args: ConfigArgs = default_args):
self.args = args
def __init__(self, args: ConfigArgs = default_args, input_args=None):
if use_torch_npu:
self.args = input_args
self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)
if utils.CUR_DEVICE is None:
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
self.args.device = utils.CUR_DEVICE
else:
self.args = args
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
@ -47,7 +72,10 @@ class KTransformersInterface(TransformersInterface):
with torch.device("meta"):
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
if use_torch_npu and input_args.optimize_config_path is not None:
optimize_config_path = input_args.optimize_config_path
elif default_args.optimize_config_path is None:
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_config_path = args.optimize_config_path
@ -60,7 +88,14 @@ class KTransformersInterface(TransformersInterface):
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
if use_torch_npu:
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)
#提前absorbed
get_absort_weight(self.model, config)
self.model.eval()
else:
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
@ -77,9 +112,92 @@ class KTransformersInterface(TransformersInterface):
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
if use_torch_npu:
self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype=torch.float16, device=self.args.device)
self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype=torch.int32, device=self.args.device)
self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype=torch.float16, device=self.args.device)
self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device=self.args.device)
self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device=self.args.device)
self._infer_lock = asyncio.Lock()
self._infer_lock = asyncio.Lock()
def decode_logits_to_token(self, logits: torch.Tensor):
if self.model.generation_config.do_sample:
logits = logits / self.temperature
torch.manual_seed(0)
probs = logits.view(1, self.model.config.vocab_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = self.next_token_fake
torch_npu._npu_topk_topp_sampling(probs, self.top_k, self.top_p, next_token, self.next_token_probs)
last = next_token.squeeze(-1)
else:
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
probs = torch.nn.functional.softmax(logits, dim=-1)
_, last = torch.topk(probs, k=1, dim=-1)
last = last.item()
self.ever_generated_ids.add(last)
return last
def decode_one_tokens_npu(self):
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
torch.cuda.set_device(torch_device)
if warm_uped and self.args.use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner
if check_runner(self.args.device):
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.init(self.args.batch_size, self.seq_length)
self.cuda_graph_runner = npu_graph_runner
utils._USE_NPU_GRAPH = True
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=self.args.device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(self.args.device)
logits = self.cuda_graph_runner(
inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
if self.args.use_cuda_graph:
warm_uped = True
if self.use_static_cache:
logits = self.model(
self.current_ids.to(torch_device),
cache_position=self.active_cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(self.current_ids, return_dict=False)[0]
self.cache.change_seq_length(1)
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
def decode_one_tokens(self):
if use_torch_npu:
return self.decode_one_tokens_npu()
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map
@ -127,9 +245,145 @@ class KTransformersInterface(TransformersInterface):
return self.logits_to_token(logits)
@torch.no_grad
def prefill_npu(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
device = self.args.device
if is_new:
self.ever_generated_ids.clear()
same_prefix = 0
flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
self.seq_length = 1
# flat_prev_ids = self.generated_ids.flatten()
# for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
# if flat_input_ids[i] == flat_prev_ids[i]:
# same_prefix += 1
# else:
# break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.cache.position[0] = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
else:
logger.warning(f"seq_length bigger than cache_lens, killed")
exit(0)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.cache.position[0] = self.seq_length + 1
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
def chunk_prefill(input_ids, cache_position):
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
return logits
logits = None
def prefill_wrapper(prof=None):
nonlocal logits
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx = cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1":
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
if use_torch_npu:
return self.prefill_npu(self, input_ids, is_new, temperature, top_p, max_tokens, max_completion_tokens)
input_ids_length = input_ids.shape[-1]
if max_tokens is not None:
max_completion_tokens = max_tokens
@ -144,6 +398,8 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
if use_torch_npu:
device = self.args.device
if is_new:
self.ever_generated_ids.clear()
@ -159,16 +415,19 @@ class KTransformersInterface(TransformersInterface):
)
self.seq_length = 1
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break
if not use_torch_npu:
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
if use_torch_npu:
self.cache.position[0] = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
@ -193,6 +452,8 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
if use_torch_npu:
self.cache.position[0] = self.seq_length + 1
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
if not (type(self) is TransformersInterface):
@ -248,4 +509,18 @@ class KTransformersInterface(TransformersInterface):
decode_time = self.profiler.get_timer_sec('decode'),
prefill_count = self.profiler.get_counter('prefill'),
decode_count = self.profiler.get_counter('decode'),
)
)
def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def run_async():
result = []
async for chunk in self.inference(local_messages, thread_id, temperature, top_p):
pass
return ""
return loop.run_until_complete(run_async())
finally:
loop.close()

View file

@ -32,6 +32,20 @@ from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
import torch.distributed as dist
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:
@ -191,11 +205,19 @@ class TransformersInterface(BackendInterfaceBase):
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
if not use_torch_npu:
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
else:
logger.debug(f"new_messages: {new_messages}")
input_ids = self.tokenizer.apply_chat_template(
new_messages, add_generation_prompt=True, return_tensors="pt"
)
if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length]
@ -212,6 +234,8 @@ class TransformersInterface(BackendInterfaceBase):
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
self.generated_ids[0, self.seq_length] = new_tokens
self.seq_length += 1
if use_torch_npu:
self.cache.position[0] = self.seq_length
return self.streamer.put(new_tokens)
@staticmethod
@ -273,14 +297,21 @@ class TransformersInterface(BackendInterfaceBase):
top_p = self.model.generation_config.top_p
if top_p == 0:
top_p = 0.0001
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
if use_torch_npu:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, do_sample=True,
top_p=top_p, temperature=temperature
)
else:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.logits_warper = self.tf_logits_warper(generation_config)
@ -372,7 +403,10 @@ class TransformersInterface(BackendInterfaceBase):
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
device = input_ids.device
if use_torch_npu:
device = self.args.device
else:
device = input_ids.device
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
@ -420,7 +454,12 @@ class TransformersInterface(BackendInterfaceBase):
else: # for's else, if output get max new tokens
yield self.streamer.end(), None
yield "", "length"
if use_torch_npu and self.args.use_cuda_graph:
utils._USE_NPU_GRAPH = False
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.destroy()
def check_is_new(self, thread_id: str):
@ -436,7 +475,87 @@ class TransformersInterface(BackendInterfaceBase):
self.last_request_id = thread_id
return True
async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
rank = torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if tp_size == world_size and tp_size > 1:
torch.distributed.barrier()
input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
dist.all_gather(all_input_sizes, input_size)
max_input_size = max([size.item() for size in all_input_sizes])
padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
dist.all_gather(all_padded_inputs, padded_input_ids)
original_size = all_input_sizes[0].item()
input_ids = all_padded_inputs[0][:, :original_size]
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
if tp_size == world_size and rank != 0:
pass
else:
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t, finish_reason in self.generate():
if t is not None:
if tp_size == world_size and rank != 0:
pass
else:
print(t, end="",flush=True)
yield t, finish_reason
if tp_size == world_size and rank != 0:
pass
else:
self.profiler.pause_timer("decode")
self.report_last_time_performance()
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
if use_torch_npu:
async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
yield tok
return
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):

View file

@ -9,7 +9,7 @@ project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -17,6 +17,21 @@ from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import asyncio
from uuid import uuid4
import torch.distributed
import subprocess
import tempfile
import atexit
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@ -100,6 +115,77 @@ def custom_openapi(app):
return app.openapi_schema
def main_npu():
torch.npu.config.allow_internal_format = False
cfg = Config()
arg_parser = ArgumentParser(cfg)
args = arg_parser.parse_args()
utils.USE_NPU_GRAPH = args.use_cuda_graph
new_chunk_size = min(max(args.chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != args.chunk_size:
args.chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {args.chunk_size}.')
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg, input_args=args)
args.port += torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if tp_size == world_size and tp_size > 1:
if torch.distributed.get_rank() == 0:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
else:
while True:
try:
context = get_thread_context_manager()
id = str(uuid4())
context.interface.sync_inference("", id)
except Exception as e:
print(f"An error occurred: {e}")
finally:
pass
else:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
def main():
cfg = Config()
@ -119,4 +205,7 @@ def main():
)
if __name__ == "__main__":
main()
if use_torch_npu:
main_npu()
else:
main()

View file

@ -16,7 +16,7 @@ from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
def create_interface(config: Config, default_args: ConfigArgs, input_args=None):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
elif config.backend_type == 'exllamav2':
@ -27,7 +27,12 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
if config.backend_type == 'ktransformers':
GlobalInterface.interface = BackendInterface(default_args, input_args)
else:
GlobalInterface.interface = BackendInterface(default_args)
GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)
class GlobalContextManager:

View file

@ -0,0 +1,210 @@
import os
from datetime import timedelta
import torch
try:
import torch_npu
except:
Warning("torch_npu not found, please install torch_npu for NPU support.")
import torch.distributed as dist
_DATA_PARALLEL_SIZE = 0
_TENSOR_PARALLEL_SIZE = 0
_DATA_PARALLEL_GROUP = None
_TENSOR_PARALLEL_RANKS = None
_TENSOR_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
_DATA_PARALLEL_RANKS = None
_GLOBAL_GROUP = None
_LM_HEAD_GROUP = None
def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):
global _DATA_PARALLEL_SIZE
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_RANKS
global _TENSOR_PARALLEL_SIZE
global _TENSOR_PARALLEL_RANKS
global _TENSOR_PARALLEL_GROUP
os.environ["MASTER_ADDR"] = "localhost"
local_rank = int(os.getenv("LOCAL_RANK", '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
torch_npu.npu.set_device(local_rank)
tp_size = tp
dp_size = world_size // tp_size
_DATA_PARALLEL_SIZE = dp_size
_TENSOR_PARALLEL_SIZE = tp_size
torch.set_num_threads(8)
timeout = timedelta(minutes=distributed_timeout_minutes)
print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}")
torch.distributed.init_process_group(
backend='hccl',
world_size=world_size, rank=local_rank
)
print(f"init process group success ------rank is {local_rank}, world_size is {world_size}")
rank = torch.distributed.get_rank()
nccl_comm_cfgs = {}
for dp_group_id in range(tp_size):
ranks = list(range(dp_group_id, world_size, tp_size))
dp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
)
if rank in ranks:
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = dp_group
_DATA_PARALLEL_RANKS = ranks
for tp_group_id in range(dp_size):
start_rank = tp_group_id * tp_size
end_rank = (tp_group_id + 1) * tp_size
ranks = list(range(start_rank, end_rank))
tp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
)
if rank in ranks:
global _TENSOR_PARALLEL_GROUP
_TENSOR_PARALLEL_GROUP = tp_group
_TENSOR_PARALLEL_RANKS = ranks
torch.manual_seed(1)
return local_rank, world_size
def get_tensor_parallel_size():
assert _TENSOR_PARALLEL_SIZE is not None, "tensor parallel size is not set"
return _TENSOR_PARALLEL_SIZE
def get_tensor_parallel_group():
assert _TENSOR_PARALLEL_GROUP is not None, "tensor parallel group is not initialized"
return _TENSOR_PARALLEL_GROUP
def get_tensor_parallel_ranks():
assert _TENSOR_PARALLEL_RANKS is not None, "tensor parallel ranks is not initialized"
return _TENSOR_PARALLEL_RANKS
def get_data_parallel_size():
assert _DATA_PARALLEL_SIZE is not None, "data parallel size is not initialized"
return _DATA_PARALLEL_SIZE
def get_data_parallel_gloo():
assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel gloo group is not initialized"
return _DATA_PARALLEL_GROUP_GLOO
def get_data_parallel_group():
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_data_parallel_ranks():
assert _DATA_PARALLEL_RANKS is not None, "data parallel ranks is not initialized"
return _DATA_PARALLEL_RANKS
def get_global_group():
assert _GLOBAL_GROUP is not None, "global group is not initialized"
return _GLOBAL_GROUP
def get_nccl_options(pg_name, nccl_comm_cfgs):
if pg_name in nccl_comm_cfgs:
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
return nccl_options
else:
return None
def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
translate_col_cut_tensors = ["ffn_down", "attn_output"]
translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"]
translate_lm_cut_tensor = ["output"]
tp = get_tensor_parallel_size()
if tp == 1 or weights.shape == torch.Size([1]):
return weights
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
rank %= tp
assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}"
if any(t in name for t in translate_col_cut_tensors):
if weights.dim() == 1:
return weights
dim = weights.shape[-1]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[:, rank * chunk_size: (rank + 1) * chunk_size]
return output_weights
elif any(t in name for t in translate_row_cut_tensors):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
elif (tp > 1) and (any(t in name for t in translate_lm_cut_tensor)):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=} {world_size=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
else:
return weights
def get_absort_weight(model, config):
local_rank = torch.distributed.get_rank()
tp = get_tensor_parallel_size()
local_rank %= tp
tp_heads = config.num_attention_heads // tp
for i in range(config.num_hidden_layers):
self = model.model.layers[i].self_attn
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
setattr(self, "q_absorb", q_absorb)
setattr(self, "out_absorb", out_absorb)
del self.orig_module.kv_b_proj
def allreduce_wrapper(func):
def wrapper(*args, **kwargs):
orig_output = func(*args, **kwargs)
if isinstance(orig_output, tuple):
if get_tensor_parallel_size() > 1:
org_dtype = orig_output[0].dtype
if org_dtype == torch.bfloat16:
dist.all_reduce(orig_output[0].to(dtype=torch.float16), op=dist.ReduceOp.SUM,
group=get_tensor_parallel_group())
else:
dist.all_reduce(orig_output[0], op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
if org_dtype == torch.bfloat16:
bf_orig_output = orig_output[0].to(dtype=org_dtype)
else:
bf_orig_output = orig_output[0]
else:
bf_orig_output = orig_output[0]
return (bf_orig_output,) + orig_output[1:]
else:
if get_tensor_parallel_size() > 1:
org_dtype = orig_output.dtype
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=torch.float16)
dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=org_dtype)
return orig_output
return wrapper

View file

@ -7,10 +7,19 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
if not torch.xpu.is_available():
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
if not torch.xpu.is_available() and not use_torch_npu:
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
if not use_torch_npu:
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from ktransformers.util.custom_gguf import *
from safetensors.torch import save_file
from abc import ABC, abstractmethod
@ -42,6 +51,7 @@ class SafeTensorLoader(ModelLoader):
tensor_device_map: dict
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
def __load_tensor_file_map(self, file_path: str):
@ -84,6 +94,7 @@ class SafeTensorLoader(ModelLoader):
# if not found_safetensor:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"):
if translate_name_to_gguf(key) in self.tensor_file_map:
key = translate_name_to_gguf(key)
@ -96,6 +107,7 @@ class SafeTensorLoader(ModelLoader):
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
return tensor.to(device)
def load_experts(self, key: str, device: str="cpu"):
@ -252,20 +264,57 @@ class SafeTensorLoader(ModelLoader):
def has_tensor(self, name: str):
return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
class W8A8SafeTensorLoader(SafeTensorLoader):
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if "deq_scale" in key:
tensor = torch.from_numpy(
np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64))
if "input_scale" in key:
tensor = tensor.to(torch.float16)
if "weight_scale" in key or "weight_offset" in key:
if "ffn" in key:
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(torch.float16)
if "input_offset" in key:
tensor = tensor.to(torch.int8)
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float16)
return tensor.to(device)
def load_dequantized_tensor(self, key: str, device: str = "cpu"):
tensor = self.load_tensor(key, device)
return tensor
class GGUFLoader(ModelLoader):
tensor_info: dict
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str):
def __init__(self, gguf_path: str, quantize: str = None):
# Check dir exist
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path)
self.safetensor_loader = None
safetensor_loader = SafeTensorLoader(gguf_path)
if quantize == "w8a8_dynamic":
safetensor_loader = W8A8SafeTensorLoader(gguf_path)
else:
safetensor_loader = SafeTensorLoader(gguf_path)
if safetensor_loader.tensor_file_map:
self.safetensor_loader = safetensor_loader
return
self.tensor_info = {}
self.gguf_path = gguf_path

View file

@ -0,0 +1,77 @@
import time
import torch
import torch_npu
import sys
import os
from ktransformers.util.utils import USE_NPU_GRAPH
if USE_NPU_GRAPH:
CAPTURE_PLUGIN_PATH = os.environ.get("CAPTURE_PLUGIN_PATH")
if CAPTURE_PLUGIN_PATH is None:
raise RuntimeError("env CAPTURE_PLUGIN_PATH not exist")
sys.path.append(CAPTURE_PLUGIN_PATH)
from libgraph_capture import graph_capture_init
from libgraph_capture import graph_capture_destroy
from libgraph_capture import graph_capture_begin
from libgraph_capture import graph_capture_end
from libgraph_capture import graph_capture_replay
from libgraph_capture import graph_capture_launch_callback
class NpuGraph:
def init(self):
ret = graph_capture_init()
if ret != 0:
exit()
def destroy(self):
ret = graph_capture_destroy()
if ret != 0:
exit()
def capture_begin(
self,
stream,
capture_error_mode="global"):
torch.npu.synchronize()
torch.npu.empty_cache()
ret = graph_capture_begin(stream, capture_error_mode)
if ret != 0:
exit()
def capture_end(
self,
stream):
ret = graph_capture_end(stream)
if ret != 0:
exit()
def replay(
self,
stream):
ret = graph_capture_replay(stream)
if ret != 0:
exit()
def launch_callback(self, func, data, block, stream):
graph_capture_launch_callback(func, data, block, stream)
class graph:
def __init__(
self,
npu_graph: NpuGraph,
pool,
stream,
capture_error_mode: str = "global"):
self.npu_graph = npu_graph
self.stream = stream.npu_stream
def __enter__(self):
self.npu_graph.capture_begin(self.stream)
def __exit__(self, exc_type, exc_val, exc_tb):
self.npu_graph.capture_end(self.stream)

View file

@ -0,0 +1,218 @@
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from typing import Dict
import acl
import torch
import torch_npu
from torch import nn
import ktransformers.util.npu_graph as npu_graph
from ktransformers.util.utils import CUR_DEVICE
class NPUGraphRunner:
def __init__(self, deviceId):
torch.npu.set_compile_mode(jit_compile=False)
self.deviceId = deviceId
self.enable = False
self.debug = False
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self.tid = None
self.past_key_value = None
def init(self, batch_size, seq_length):
self.tmp_g = npu_graph.NpuGraph()
self.graph = torch.npu.NPUGraph()
self.main_stream = torch_npu.npu.Stream(device=self.deviceId)
self.update_stream = torch_npu.npu.Stream(device=self.deviceId)
self.stream = self.main_stream.npu_stream
self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId)
self.context, ret = acl.rt.get_context(self.deviceId)
if ret != 0:
print("get_context failed! ret: " + str(ret))
exit(-1)
self.exit_flag = False
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
self.workspace = None
if self.tid is None:
def process_callback(args_list):
ins = args_list[0]
ret = acl.rt.set_context(ins.context)
if ret != 0:
print("set_context failed! ret: " + str(ret))
exit(-1)
while True:
acl.rt.process_report(1)
if ins.exit_flag:
break
self.tid, ret = acl.util.start_thread(process_callback, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret = acl.rt.subscribe_report(self.tid, self.stream)
if ret != 0:
print("subscribe_report failed!")
exit(-1)
def destroy(self):
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Begin -------------\n', end='')
self.exit_flag = True
ret = acl.rt.unsubscribe_report(self.tid, self.stream)
if ret != 0:
print("unsubscribe_report failed!")
exit(-1)
self.enable = False
ret = acl.util.stop_thread(self.tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
self.tid = None
self.workspace = None
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
del self.graph
self.tmp_g.destroy()
destroy_runner(self.deviceId)
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Finish -------------\n', end='')
def capture(
self,
model,
cur_token,
position_ids,
cache_position,
past_key_values,
main_device,
**kwargs,
) -> None:
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Begin -------------\n', end='')
self.enable = True
self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
self.seq_length = inputs_embeds.size()[1]
self.main_device = main_device
with torch.no_grad():
with torch.npu.graph(self.graph, stream=self.main_stream):
self.logits = model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs)[0]
if past_key_values != None:
past_key_values.change_seq_length(-1)
self.input_buffers = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"cache_position": cache_position,
}
self.output_buffers = {"logits": self.logits}
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Finish -------------\n', end='')
return
def forward(
self,
inputs_embeds,
position_ids,
cache_position,
) -> torch.Tensor:
def ifa_update_sync(param):
with torch.npu.stream(self.update_stream):
for i in range(len(self.handle)):
if self.first_update is False:
q_nope, kvCache, q_pe, kRopeCache, num_heads, \
softmax_scale, layer_idx, attn_output, softmax_lse = self.ifa_param[i]
torch.npu.graph_task_update_begin(self.update_stream, self.handle[i])
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
kvCache,
kvCache,
workspace=self.workspace,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.past_key_value.page_table_list[layer_idx],
block_size=self.past_key_value.page_size,
actual_seq_lengths_kv=self.past_key_value.position,
out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(self.update_stream)
self.event[i].record(self.update_stream)
self.ifa_update_tid, ret = acl.util.start_thread(ifa_update_sync, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret1 = acl.rt.memcpy(self.input_buffers["inputs_embeds"].data_ptr(), inputs_embeds.numel() * 2,
inputs_embeds.data_ptr(), inputs_embeds.numel() * 2, 3)
ret2 = acl.rt.memcpy(self.input_buffers["position_ids"].data_ptr(), position_ids.numel() * 8,
position_ids.data_ptr(), position_ids.numel() * 8, 3)
ret3 = acl.rt.memcpy(self.input_buffers["cache_position"].data_ptr(), cache_position.numel() * 8,
cache_position.data_ptr(), cache_position.numel() * 8, 3)
torch_npu.npu.synchronize()
with torch_npu.npu.stream(self.main_stream):
self.graph.replay()
self.first_update = False
ret = acl.util.stop_thread(self.ifa_update_tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
else:
self.ifa_update_tid = None
return self.output_buffers["logits"]
def launch_callback(self, func, data, block, stream):
self.tmp_g.launch_callback(func, data, block, stream)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
runner_dict = dict()
def check_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
return True
else:
return False
def destroy_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is not None:
runner_dict[deviceId] = None
def get_or_create_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
runner = NPUGraphRunner(deviceId)
runner_dict[deviceId] = runner
return runner

View file

@ -31,8 +31,35 @@ if not torch.xpu.is_available():
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
import os
import re
import torch.distributed as dist
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
warm_uped = False
W8A8_ENABLE = False
Q4_GGUF_LODER = None
USE_NPU_GRAPH = None
WARM_UP_SKIP_CNT = [1, 1]
_USE_NPU_GRAPH = False
_MAX_DECODE_PROFILE = 3
CUR_DEVICE = None
_MAX_CHUNK_SIZE = int(max(os.getenv("_MAX_CHUNK_SIZE", 4096), 512))
def get_use_npu_graph():
assert _USE_NPU_GRAPH is not None, "use npu graph is not setting"
return _USE_NPU_GRAPH
def get_free_ports(n: int, continue_prot: list):
sockets = []
ports = []
@ -50,6 +77,10 @@ def get_free_ports(n: int, continue_prot: list):
return ports
def get_compute_capability(device:torch.device = None):
if use_torch_npu:
return 0
if torch.cuda.is_available():
if device is None:
num_gpus = torch.cuda.device_count()
@ -97,9 +128,16 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
if "cpu" in all_device_list:
all_device_list.remove("cpu")
if use_torch_npu:
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
all_device_list = list(all_device_list)
return all_device_list
# TODO: support NPU
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
@ -109,6 +147,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
key = prefix + name
translated_key = key
# TODO: Merge all loader.
# I know this is ugly but lets do it for now.
if isinstance(gguf_loader, SafeTensorLoader):
@ -120,7 +159,13 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
if use_torch_npu:
device = "cpu" if "embd" in translated_key else CUR_DEVICE
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
else:
print(f"loading {translated_key} to {device}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
@ -149,6 +194,8 @@ def sync_all_device(all_device_list):
torch.cuda.synchronize(device)
elif "xpu" in device.lower():
torch.xpu.synchronize(device)
elif use_torch_npu:
torch_npu.synchronize(device)
else:
raise RuntimeError("The device {} is not available".format(device))
@ -228,20 +275,68 @@ def tf_logits_warper(generation_config):
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
if use_torch_npu:
vocabulary_size = model.config.vocab_size
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
torch_device = CUR_DEVICE
else:
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
tokens = []
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if cuda_graph_runner is None:
use_cuda_graph = False
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
if use_cuda_graph:
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
# torch.cuda.set_device(torch_device)
torch_npu.npu.set_device(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
if past_key_values != None:
past_key_values.change_seq_length(1)
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
for device in all_cuda_device:
# torch.cuda.synchronize(device)
torch_npu.npu.synchronize(device)
if generation_config.do_sample:
logits = logits / temperature
torch.manual_seed(0)
probs = logits.view(batch_size, vocabulary_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = next_token_fake
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
next_token = next_token.squeeze(-1)
else:
next_token_scores = logits_warper(inputs, logits[:, -1, :])
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if use_torch_npu:
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
if cuda_graph_runner is None:
use_cuda_graph = False
if use_cuda_graph:
@ -252,6 +347,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
@ -279,6 +376,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all()
@ -288,11 +386,88 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
return logits
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
global warm_uped
global _USE_NPU_GRAPH
if use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
npu_graph_runner.init(batch_size, seq_length)
with torch_npu.npu.stream(npu_graph_runner.main_stream):
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
warm_uped = True
_USE_NPU_GRAPH = True
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
cuda_graph_runner = npu_graph_runner
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
cache_position, past_key_values, logits_warper, generation_config,
use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
npu_graph_runner.destroy()
_USE_NPU_GRAPH = False
else:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
with torch.no_grad():
@ -304,6 +479,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
else:
past_key_values = DynamicNormalCache.from_legacy_cache(None)
elif use_torch_npu and static_cache:
assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'
past_key_values = static_cache
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
print('[WARN] current staticCache size exceeded, try create new staticCache...')
past_key_values = StaticCache(
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
)
else:
past_key_values.reset()
elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
@ -320,19 +505,67 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
if use_torch_npu:
past_key_values.position[0] = seq_length + 1
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
start_time = time.time()
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
logits = None
def prefill_wrapper(prof=None):
nonlocal logits
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
if use_torch_npu:
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
WARM_UP_SKIP_CNT[0] -= 1
else:
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
@ -348,56 +581,106 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
prefill_count = seq_length
prefill_time = first_token_time
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
elif not use_torch_npu:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token
tokens.append(int(next_token))
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)
seq_length += 1
if use_torch_npu:
past_key_values.position += 1
cuda_graph_runner = None
start_time = time.time()
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
if not use_torch_npu:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
else:
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
WARM_UP_SKIP_CNT[1] -= 1
total_time = time.time() - start_time
tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time
print("")
if not use_torch_npu:
print("")
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
else:
tp_size = get_tensor_parallel_size()
if torch.distributed.get_rank() % tp_size == 0:
rank = f"[rank:{torch.distributed.get_rank()}]"
msg = f"\n{rank} Eval Time\n"
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
msg += rank + f"eval count: {tokens_generated} token(s)\n"
msg += rank + f"eval duration: {total_time:.9f}s\n"
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
print(msg)
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
return tokens