support npu

This commit is contained in:
djw 2025-07-21 12:26:14 +00:00
parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions

View file

@ -44,6 +44,10 @@ option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM"
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
@ -90,6 +94,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
endif ()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
else()
if(KTRANSFORMERS_USE_NPU)
list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
endif()
check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
list(APPEND ARCH_FLAGS -mfp16-format=ieee)
@ -117,37 +124,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
if(NOT KTRANSFORMERS_USE_NPU)
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(__HAS_AMX__)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(__HAS_AMX__)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
if (MSVC)
# instruction set detection for MSVC only
if (LLAMA_NATIVE)
@ -281,6 +289,8 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
if (KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED)
if(HIP_FOUND)

View file

@ -0,0 +1,257 @@
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os
import platform
import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import torch.distributed as dist
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils
from ktransformers.models.custom_cache import StaticCache
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
torch.npu.config.allow_internal_format = True
ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "npu/DeepSeek-V3-Chat.yaml",
}
torch.npu.set_compile_mode(jit_compile=False)
import sys, signal, faulthandler
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)
def local_chat(
model_path: str | None = None,
optimize_config_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = False,
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_size: int = utils._MAX_CHUNK_SIZE,
q4_gguf_path: str | None = None,
tp: int = 1,
):
utils.USE_NPU_GRAPH = use_cuda_graph
torch.npu.config.allow_internal_format = False
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
local_rank, world_size = setup_model_parallel(tp=tp)
if utils.CUR_DEVICE is None:
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if use_cuda_graph:
from ktransformers.util import npu_graph_runner
npu_graph_runner.LAYER_ID = config.num_hidden_layers
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_config_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0]) if local_rank == 0 else None
optimize_config_path = default_optimize_rules[config.architectures[0]]
print(f'{optimize_config_path=}') if local_rank == 0 else None
else:
optimize_config_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, q4_gguf_path=q4_gguf_path)
get_absort_weight(model, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
print(f"generation config can't auto create, make default. Message: {e}")
gen_config = GenerationConfig(
temperature=0.6,
top_p=0.95,
do_sample=True
)
model.generation_config = gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system()
if system == "Windows":
os.system("cls") if local_rank == 0 else None
else:
os.system("clear") if local_rank == 0 else None
print(f"{model=}") if local_rank == 0 else None
batch_size, seq_length = 1, 1024
device_map = model.gguf_loader.tensor_device_map
static_cache = StaticCache(
config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map,
dtype = model.dtype
)
chunk_size = int(chunk_size)
new_chunk_size = min(max(chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != chunk_size:
chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {chunk_size}.')
torch.distributed.barrier()
while True:
if local_rank == 0:
try:
content = input("Chat: ").strip()
except KeyboardInterrupt:
dist.barrier()
print('Exit all ranks with KeyboardInterrupt!')
sys.exit(0)
if content.startswith('"""'): # prefix """
# multi lines input
content = content[3:] + "\n"
while True:
line = input("")
if line.endswith('"""'):
# end multi lines input
line = line[:-3] # suffix """
if line:
content += line + "\n"
break
else:
content += line + "\n"
if content == "":
if prompt_file != None:
content = open(prompt_file, "r").read()
else:
continue
elif os.path.isfile(content):
f = open(content, "r")
content = f.readlines()
f.close()
else:
content = [f"{len(content)},{max_new_tokens},{content}"]
else:
content = [""]
for line in content:
content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
if world_size > 1:
content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE)
all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_sizes, content_size)
max_content_size = max([size.item() for size in all_content_sizes])
padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
padded_content_tensor[:len(content_tensor)] = content_tensor
all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_tensors, padded_content_tensor)
content_tensor = all_content_tensors[0][:all_content_sizes[0].item()]
line = bytes(content_tensor.cpu().numpy()).decode()
parts = line.split(",")
input_tokens = int(parts[0])
max_new_tokens = int(parts[1])
line = line[line.index(",", line.index(",") + 1) + 1:]
messages = [{"role": "user", "content": line}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
)
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim,
static_cache=static_cache
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
static_cache=static_cache
)
if __name__ == "__main__":
fire.Fire(local_chat)

View file

@ -16,6 +16,16 @@ try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
@ -37,6 +47,10 @@ class StaticCache(transformers.StaticCache):
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
Cache.__init__(self)
self.max_batch_size = max_batch_size
if use_torch_npu:
self.position = [0]
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
if config.architectures[0] == "DeepseekV3ForCausalLM":
@ -56,8 +70,18 @@ class StaticCache(transformers.StaticCache):
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
if use_torch_npu:
self.page_size = 128
self.page_size_tensor = torch.tensor(
self.page_size,
dtype=torch.int32,
).npu()
self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
else:
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
@ -71,9 +95,14 @@ class StaticCache(transformers.StaticCache):
target_device = device
if target_device not in self.page_table_map:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
if use_torch_npu:
page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
else:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
self.page_table_map[target_device] = page_table
self.page_table_list.append(self.page_table_map[target_device])
@ -140,11 +169,24 @@ class StaticCache(transformers.StaticCache):
self.past_tokens[layer_idx] += cache_position.size(0)
#print(cache_position)
if self.is_MLA:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
if use_torch_npu:
page_idx = cache_position // self.page_size_tensor
page_offset = cache_position % self.page_size_tensor
page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)
page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
page_idx = page_idx + page_idx_offset.unsqueeze(1)
combined = torch.cat([key_states, value_states], dim=-1)
combined = combined.contiguous()
else:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
return k_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states
@ -178,6 +220,9 @@ class StaticCache(transformers.StaticCache):
if self.value_cache[layer_idx] is not None:
self.value_cache[layer_idx].zero_()
self.past_tokens[layer_idx] = 0
if use_torch_npu:
self.position = [0]
def remove_suffix(self, start_pos):
for layer_idx in range(len(self.key_cache)):

View file

@ -27,8 +27,12 @@ try:
from flash_attn import flash_attn_func
except:
pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
try:
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
except:
Warning("triton not found, if you are using npu, ignore this.")
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:

View file

@ -1,5 +1,8 @@
import torch
import flashinfer
try:
import flashinfer
except:
Warning("flashinfer not found, if you are using npu, ignore this.")
import gc
try:
from flash_attn import flash_attn_with_kvcache

View file

@ -5,7 +5,11 @@ Version : 0.2.3
'''
import torch
import os
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
try:
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
except:
Warning("triton not found, if you are using npu, ignore this.")
flashinfer_enabled = False

View file

@ -14,7 +14,15 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
if not torch.xpu.is_available():
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
if not torch.xpu.is_available() and not use_torch_npu:
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader

View file

@ -16,6 +16,7 @@ from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
from ktransformers.util import utils
def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
for name, child in module._modules.items():
@ -114,7 +115,7 @@ def translate_model_config(model_config: PretrainedConfig):
return model_config
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0", q4_gguf_path=""):
with open(rule_file, 'r', encoding='utf-8') as f:
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
@ -123,15 +124,29 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
model_config = translate_model_config(model_config)
weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
if q4_gguf_path:
q4_gguf_loader = GGUFLoader(q4_gguf_path)
utils.Q4_GGUF_LODER = q4_gguf_loader
gguf_loader = GGUFLoader(gguf_path, getattr(model_config, "quantize", None))
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
else:
weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
load_weights(module, weights_loader, device=default_device)
module.gguf_loader = weights_loader
del_meta(module)
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

View file

@ -0,0 +1,76 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "npu:0"
prefill_device: "npu:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "npu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "npu"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "npu"
prefill_device: "npu"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@ -22,6 +22,10 @@ class ArgumentParser:
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--q4_gguf_path", type=str, default=None)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)

View file

@ -8,6 +8,7 @@ class ConfigArgs(BaseModel):
model_dir: Optional[str] = Field(..., description="Path to model directory")
optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
tp: int = Field(None, description="tp size")
class Config:
protected_namespaces = ()

View file

@ -1,4 +1,19 @@
import torch
from torch import nn
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
from ktransformers.util.utils import get_device, get_all_used_cuda_device
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
import os
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
@ -19,6 +34,9 @@ from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage
warm_uped = False
class KTransformersThreadContext(TransformersThreadContext):
@ -26,8 +44,15 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
def __init__(self, args: ConfigArgs = default_args):
self.args = args
def __init__(self, args: ConfigArgs = default_args, input_args=None):
if use_torch_npu:
self.args = input_args
self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)
if utils.CUR_DEVICE is None:
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
self.args.device = utils.CUR_DEVICE
else:
self.args = args
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
@ -47,7 +72,10 @@ class KTransformersInterface(TransformersInterface):
with torch.device("meta"):
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
if use_torch_npu and input_args.optimize_config_path is not None:
optimize_config_path = input_args.optimize_config_path
elif default_args.optimize_config_path is None:
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_config_path = args.optimize_config_path
@ -60,7 +88,14 @@ class KTransformersInterface(TransformersInterface):
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
if use_torch_npu:
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)
#提前absorbed
get_absort_weight(self.model, config)
self.model.eval()
else:
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
@ -77,9 +112,92 @@ class KTransformersInterface(TransformersInterface):
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
if use_torch_npu:
self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype=torch.float16, device=self.args.device)
self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype=torch.int32, device=self.args.device)
self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype=torch.float16, device=self.args.device)
self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device=self.args.device)
self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device=self.args.device)
self._infer_lock = asyncio.Lock()
self._infer_lock = asyncio.Lock()
def decode_logits_to_token(self, logits: torch.Tensor):
if self.model.generation_config.do_sample:
logits = logits / self.temperature
torch.manual_seed(0)
probs = logits.view(1, self.model.config.vocab_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = self.next_token_fake
torch_npu._npu_topk_topp_sampling(probs, self.top_k, self.top_p, next_token, self.next_token_probs)
last = next_token.squeeze(-1)
else:
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
probs = torch.nn.functional.softmax(logits, dim=-1)
_, last = torch.topk(probs, k=1, dim=-1)
last = last.item()
self.ever_generated_ids.add(last)
return last
def decode_one_tokens_npu(self):
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
torch.cuda.set_device(torch_device)
if warm_uped and self.args.use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner
if check_runner(self.args.device):
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.init(self.args.batch_size, self.seq_length)
self.cuda_graph_runner = npu_graph_runner
utils._USE_NPU_GRAPH = True
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=self.args.device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(self.args.device)
logits = self.cuda_graph_runner(
inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
if self.args.use_cuda_graph:
warm_uped = True
if self.use_static_cache:
logits = self.model(
self.current_ids.to(torch_device),
cache_position=self.active_cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(self.current_ids, return_dict=False)[0]
self.cache.change_seq_length(1)
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
def decode_one_tokens(self):
if use_torch_npu:
return self.decode_one_tokens_npu()
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map
@ -127,9 +245,145 @@ class KTransformersInterface(TransformersInterface):
return self.logits_to_token(logits)
@torch.no_grad
def prefill_npu(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
device = self.args.device
if is_new:
self.ever_generated_ids.clear()
same_prefix = 0
flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
self.seq_length = 1
# flat_prev_ids = self.generated_ids.flatten()
# for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
# if flat_input_ids[i] == flat_prev_ids[i]:
# same_prefix += 1
# else:
# break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.cache.position[0] = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
else:
logger.warning(f"seq_length bigger than cache_lens, killed")
exit(0)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.cache.position[0] = self.seq_length + 1
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
def chunk_prefill(input_ids, cache_position):
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
return logits
logits = None
def prefill_wrapper(prof=None):
nonlocal logits
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx = cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1":
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
if use_torch_npu:
return self.prefill_npu(self, input_ids, is_new, temperature, top_p, max_tokens, max_completion_tokens)
input_ids_length = input_ids.shape[-1]
if max_tokens is not None:
max_completion_tokens = max_tokens
@ -144,6 +398,8 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
if use_torch_npu:
device = self.args.device
if is_new:
self.ever_generated_ids.clear()
@ -159,16 +415,19 @@ class KTransformersInterface(TransformersInterface):
)
self.seq_length = 1
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break
if not use_torch_npu:
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
if use_torch_npu:
self.cache.position[0] = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
@ -193,6 +452,8 @@ class KTransformersInterface(TransformersInterface):
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
if use_torch_npu:
self.cache.position[0] = self.seq_length + 1
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
if not (type(self) is TransformersInterface):
@ -248,4 +509,18 @@ class KTransformersInterface(TransformersInterface):
decode_time = self.profiler.get_timer_sec('decode'),
prefill_count = self.profiler.get_counter('prefill'),
decode_count = self.profiler.get_counter('decode'),
)
)
def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def run_async():
result = []
async for chunk in self.inference(local_messages, thread_id, temperature, top_p):
pass
return ""
return loop.run_until_complete(run_async())
finally:
loop.close()

View file

@ -32,6 +32,20 @@ from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
import torch.distributed as dist
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:
@ -191,11 +205,19 @@ class TransformersInterface(BackendInterfaceBase):
# input_ids = self.tokenizer.apply_chat_template(
# new_messages, return_tensors="pt", add_generation_prompt=True
# ).to(self.args.device)
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
if not use_torch_npu:
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
else:
logger.debug(f"new_messages: {new_messages}")
input_ids = self.tokenizer.apply_chat_template(
new_messages, add_generation_prompt=True, return_tensors="pt"
)
if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length]
@ -212,6 +234,8 @@ class TransformersInterface(BackendInterfaceBase):
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
self.generated_ids[0, self.seq_length] = new_tokens
self.seq_length += 1
if use_torch_npu:
self.cache.position[0] = self.seq_length
return self.streamer.put(new_tokens)
@staticmethod
@ -273,14 +297,21 @@ class TransformersInterface(BackendInterfaceBase):
top_p = self.model.generation_config.top_p
if top_p == 0:
top_p = 0.0001
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
if use_torch_npu:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, do_sample=True,
top_p=top_p, temperature=temperature
)
else:
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.logits_warper = self.tf_logits_warper(generation_config)
@ -372,7 +403,10 @@ class TransformersInterface(BackendInterfaceBase):
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
device = input_ids.device
if use_torch_npu:
device = self.args.device
else:
device = input_ids.device
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
@ -420,7 +454,12 @@ class TransformersInterface(BackendInterfaceBase):
else: # for's else, if output get max new tokens
yield self.streamer.end(), None
yield "", "length"
if use_torch_npu and self.args.use_cuda_graph:
utils._USE_NPU_GRAPH = False
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.destroy()
def check_is_new(self, thread_id: str):
@ -436,7 +475,87 @@ class TransformersInterface(BackendInterfaceBase):
self.last_request_id = thread_id
return True
async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
rank = torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if tp_size == world_size and tp_size > 1:
torch.distributed.barrier()
input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
dist.all_gather(all_input_sizes, input_size)
max_input_size = max([size.item() for size in all_input_sizes])
padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
dist.all_gather(all_padded_inputs, padded_input_ids)
original_size = all_input_sizes[0].item()
input_ids = all_padded_inputs[0][:, :original_size]
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
if tp_size == world_size and rank != 0:
pass
else:
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t, finish_reason in self.generate():
if t is not None:
if tp_size == world_size and rank != 0:
pass
else:
print(t, end="",flush=True)
yield t, finish_reason
if tp_size == world_size and rank != 0:
pass
else:
self.profiler.pause_timer("decode")
self.report_last_time_performance()
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
if use_torch_npu:
async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
yield tok
return
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):

View file

@ -9,7 +9,7 @@ project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -17,6 +17,21 @@ from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import asyncio
from uuid import uuid4
import torch.distributed
import subprocess
import tempfile
import atexit
try:
import torch_npu
from ktransformers.util import utils
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@ -100,6 +115,77 @@ def custom_openapi(app):
return app.openapi_schema
def main_npu():
torch.npu.config.allow_internal_format = False
cfg = Config()
arg_parser = ArgumentParser(cfg)
args = arg_parser.parse_args()
utils.USE_NPU_GRAPH = args.use_cuda_graph
new_chunk_size = min(max(args.chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != args.chunk_size:
args.chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {args.chunk_size}.')
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg, input_args=args)
args.port += torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if tp_size == world_size and tp_size > 1:
if torch.distributed.get_rank() == 0:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
else:
while True:
try:
context = get_thread_context_manager()
id = str(uuid4())
context.interface.sync_inference("", id)
except Exception as e:
print(f"An error occurred: {e}")
finally:
pass
else:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
def main():
cfg = Config()
@ -119,4 +205,7 @@ def main():
)
if __name__ == "__main__":
main()
if use_torch_npu:
main_npu()
else:
main()

View file

@ -16,7 +16,7 @@ from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
def create_interface(config: Config, default_args: ConfigArgs, input_args=None):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
elif config.backend_type == 'exllamav2':
@ -27,7 +27,12 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
if config.backend_type == 'ktransformers':
GlobalInterface.interface = BackendInterface(default_args, input_args)
else:
GlobalInterface.interface = BackendInterface(default_args)
GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)
class GlobalContextManager:

View file

@ -0,0 +1,210 @@
import os
from datetime import timedelta
import torch
try:
import torch_npu
except:
Warning("torch_npu not found, please install torch_npu for NPU support.")
import torch.distributed as dist
_DATA_PARALLEL_SIZE = 0
_TENSOR_PARALLEL_SIZE = 0
_DATA_PARALLEL_GROUP = None
_TENSOR_PARALLEL_RANKS = None
_TENSOR_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
_DATA_PARALLEL_RANKS = None
_GLOBAL_GROUP = None
_LM_HEAD_GROUP = None
def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):
global _DATA_PARALLEL_SIZE
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_RANKS
global _TENSOR_PARALLEL_SIZE
global _TENSOR_PARALLEL_RANKS
global _TENSOR_PARALLEL_GROUP
os.environ["MASTER_ADDR"] = "localhost"
local_rank = int(os.getenv("LOCAL_RANK", '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
torch_npu.npu.set_device(local_rank)
tp_size = tp
dp_size = world_size // tp_size
_DATA_PARALLEL_SIZE = dp_size
_TENSOR_PARALLEL_SIZE = tp_size
torch.set_num_threads(8)
timeout = timedelta(minutes=distributed_timeout_minutes)
print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}")
torch.distributed.init_process_group(
backend='hccl',
world_size=world_size, rank=local_rank
)
print(f"init process group success ------rank is {local_rank}, world_size is {world_size}")
rank = torch.distributed.get_rank()
nccl_comm_cfgs = {}
for dp_group_id in range(tp_size):
ranks = list(range(dp_group_id, world_size, tp_size))
dp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
)
if rank in ranks:
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = dp_group
_DATA_PARALLEL_RANKS = ranks
for tp_group_id in range(dp_size):
start_rank = tp_group_id * tp_size
end_rank = (tp_group_id + 1) * tp_size
ranks = list(range(start_rank, end_rank))
tp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
)
if rank in ranks:
global _TENSOR_PARALLEL_GROUP
_TENSOR_PARALLEL_GROUP = tp_group
_TENSOR_PARALLEL_RANKS = ranks
torch.manual_seed(1)
return local_rank, world_size
def get_tensor_parallel_size():
assert _TENSOR_PARALLEL_SIZE is not None, "tensor parallel size is not set"
return _TENSOR_PARALLEL_SIZE
def get_tensor_parallel_group():
assert _TENSOR_PARALLEL_GROUP is not None, "tensor parallel group is not initialized"
return _TENSOR_PARALLEL_GROUP
def get_tensor_parallel_ranks():
assert _TENSOR_PARALLEL_RANKS is not None, "tensor parallel ranks is not initialized"
return _TENSOR_PARALLEL_RANKS
def get_data_parallel_size():
assert _DATA_PARALLEL_SIZE is not None, "data parallel size is not initialized"
return _DATA_PARALLEL_SIZE
def get_data_parallel_gloo():
assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel gloo group is not initialized"
return _DATA_PARALLEL_GROUP_GLOO
def get_data_parallel_group():
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_data_parallel_ranks():
assert _DATA_PARALLEL_RANKS is not None, "data parallel ranks is not initialized"
return _DATA_PARALLEL_RANKS
def get_global_group():
assert _GLOBAL_GROUP is not None, "global group is not initialized"
return _GLOBAL_GROUP
def get_nccl_options(pg_name, nccl_comm_cfgs):
if pg_name in nccl_comm_cfgs:
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
return nccl_options
else:
return None
def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
translate_col_cut_tensors = ["ffn_down", "attn_output"]
translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"]
translate_lm_cut_tensor = ["output"]
tp = get_tensor_parallel_size()
if tp == 1 or weights.shape == torch.Size([1]):
return weights
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
rank %= tp
assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}"
if any(t in name for t in translate_col_cut_tensors):
if weights.dim() == 1:
return weights
dim = weights.shape[-1]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[:, rank * chunk_size: (rank + 1) * chunk_size]
return output_weights
elif any(t in name for t in translate_row_cut_tensors):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
elif (tp > 1) and (any(t in name for t in translate_lm_cut_tensor)):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=} {world_size=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
else:
return weights
def get_absort_weight(model, config):
local_rank = torch.distributed.get_rank()
tp = get_tensor_parallel_size()
local_rank %= tp
tp_heads = config.num_attention_heads // tp
for i in range(config.num_hidden_layers):
self = model.model.layers[i].self_attn
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
setattr(self, "q_absorb", q_absorb)
setattr(self, "out_absorb", out_absorb)
del self.orig_module.kv_b_proj
def allreduce_wrapper(func):
def wrapper(*args, **kwargs):
orig_output = func(*args, **kwargs)
if isinstance(orig_output, tuple):
if get_tensor_parallel_size() > 1:
org_dtype = orig_output[0].dtype
if org_dtype == torch.bfloat16:
dist.all_reduce(orig_output[0].to(dtype=torch.float16), op=dist.ReduceOp.SUM,
group=get_tensor_parallel_group())
else:
dist.all_reduce(orig_output[0], op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
if org_dtype == torch.bfloat16:
bf_orig_output = orig_output[0].to(dtype=org_dtype)
else:
bf_orig_output = orig_output[0]
else:
bf_orig_output = orig_output[0]
return (bf_orig_output,) + orig_output[1:]
else:
if get_tensor_parallel_size() > 1:
org_dtype = orig_output.dtype
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=torch.float16)
dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=org_dtype)
return orig_output
return wrapper

View file

@ -7,10 +7,19 @@ from typing import Sequence
import os
from enum import IntEnum
import torch
if not torch.xpu.is_available():
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
if not torch.xpu.is_available() and not use_torch_npu:
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
if not use_torch_npu:
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from ktransformers.util.custom_gguf import *
from safetensors.torch import save_file
from abc import ABC, abstractmethod
@ -42,6 +51,7 @@ class SafeTensorLoader(ModelLoader):
tensor_device_map: dict
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
def __load_tensor_file_map(self, file_path: str):
@ -84,6 +94,7 @@ class SafeTensorLoader(ModelLoader):
# if not found_safetensor:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"):
if translate_name_to_gguf(key) in self.tensor_file_map:
key = translate_name_to_gguf(key)
@ -96,6 +107,7 @@ class SafeTensorLoader(ModelLoader):
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
return tensor.to(device)
def load_experts(self, key: str, device: str="cpu"):
@ -252,20 +264,57 @@ class SafeTensorLoader(ModelLoader):
def has_tensor(self, name: str):
return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
class W8A8SafeTensorLoader(SafeTensorLoader):
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if "deq_scale" in key:
tensor = torch.from_numpy(
np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64))
if "input_scale" in key:
tensor = tensor.to(torch.float16)
if "weight_scale" in key or "weight_offset" in key:
if "ffn" in key:
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(torch.float16)
if "input_offset" in key:
tensor = tensor.to(torch.int8)
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float16)
return tensor.to(device)
def load_dequantized_tensor(self, key: str, device: str = "cpu"):
tensor = self.load_tensor(key, device)
return tensor
class GGUFLoader(ModelLoader):
tensor_info: dict
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str):
def __init__(self, gguf_path: str, quantize: str = None):
# Check dir exist
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path)
self.safetensor_loader = None
safetensor_loader = SafeTensorLoader(gguf_path)
if quantize == "w8a8_dynamic":
safetensor_loader = W8A8SafeTensorLoader(gguf_path)
else:
safetensor_loader = SafeTensorLoader(gguf_path)
if safetensor_loader.tensor_file_map:
self.safetensor_loader = safetensor_loader
return
self.tensor_info = {}
self.gguf_path = gguf_path

View file

@ -0,0 +1,77 @@
import time
import torch
import torch_npu
import sys
import os
from ktransformers.util.utils import USE_NPU_GRAPH
if USE_NPU_GRAPH:
CAPTURE_PLUGIN_PATH = os.environ.get("CAPTURE_PLUGIN_PATH")
if CAPTURE_PLUGIN_PATH is None:
raise RuntimeError("env CAPTURE_PLUGIN_PATH not exist")
sys.path.append(CAPTURE_PLUGIN_PATH)
from libgraph_capture import graph_capture_init
from libgraph_capture import graph_capture_destroy
from libgraph_capture import graph_capture_begin
from libgraph_capture import graph_capture_end
from libgraph_capture import graph_capture_replay
from libgraph_capture import graph_capture_launch_callback
class NpuGraph:
def init(self):
ret = graph_capture_init()
if ret != 0:
exit()
def destroy(self):
ret = graph_capture_destroy()
if ret != 0:
exit()
def capture_begin(
self,
stream,
capture_error_mode="global"):
torch.npu.synchronize()
torch.npu.empty_cache()
ret = graph_capture_begin(stream, capture_error_mode)
if ret != 0:
exit()
def capture_end(
self,
stream):
ret = graph_capture_end(stream)
if ret != 0:
exit()
def replay(
self,
stream):
ret = graph_capture_replay(stream)
if ret != 0:
exit()
def launch_callback(self, func, data, block, stream):
graph_capture_launch_callback(func, data, block, stream)
class graph:
def __init__(
self,
npu_graph: NpuGraph,
pool,
stream,
capture_error_mode: str = "global"):
self.npu_graph = npu_graph
self.stream = stream.npu_stream
def __enter__(self):
self.npu_graph.capture_begin(self.stream)
def __exit__(self, exc_type, exc_val, exc_tb):
self.npu_graph.capture_end(self.stream)

View file

@ -0,0 +1,218 @@
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from typing import Dict
import acl
import torch
import torch_npu
from torch import nn
import ktransformers.util.npu_graph as npu_graph
from ktransformers.util.utils import CUR_DEVICE
class NPUGraphRunner:
def __init__(self, deviceId):
torch.npu.set_compile_mode(jit_compile=False)
self.deviceId = deviceId
self.enable = False
self.debug = False
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self.tid = None
self.past_key_value = None
def init(self, batch_size, seq_length):
self.tmp_g = npu_graph.NpuGraph()
self.graph = torch.npu.NPUGraph()
self.main_stream = torch_npu.npu.Stream(device=self.deviceId)
self.update_stream = torch_npu.npu.Stream(device=self.deviceId)
self.stream = self.main_stream.npu_stream
self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId)
self.context, ret = acl.rt.get_context(self.deviceId)
if ret != 0:
print("get_context failed! ret: " + str(ret))
exit(-1)
self.exit_flag = False
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
self.workspace = None
if self.tid is None:
def process_callback(args_list):
ins = args_list[0]
ret = acl.rt.set_context(ins.context)
if ret != 0:
print("set_context failed! ret: " + str(ret))
exit(-1)
while True:
acl.rt.process_report(1)
if ins.exit_flag:
break
self.tid, ret = acl.util.start_thread(process_callback, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret = acl.rt.subscribe_report(self.tid, self.stream)
if ret != 0:
print("subscribe_report failed!")
exit(-1)
def destroy(self):
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Begin -------------\n', end='')
self.exit_flag = True
ret = acl.rt.unsubscribe_report(self.tid, self.stream)
if ret != 0:
print("unsubscribe_report failed!")
exit(-1)
self.enable = False
ret = acl.util.stop_thread(self.tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
self.tid = None
self.workspace = None
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
del self.graph
self.tmp_g.destroy()
destroy_runner(self.deviceId)
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Finish -------------\n', end='')
def capture(
self,
model,
cur_token,
position_ids,
cache_position,
past_key_values,
main_device,
**kwargs,
) -> None:
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Begin -------------\n', end='')
self.enable = True
self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
self.seq_length = inputs_embeds.size()[1]
self.main_device = main_device
with torch.no_grad():
with torch.npu.graph(self.graph, stream=self.main_stream):
self.logits = model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs)[0]
if past_key_values != None:
past_key_values.change_seq_length(-1)
self.input_buffers = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"cache_position": cache_position,
}
self.output_buffers = {"logits": self.logits}
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Finish -------------\n', end='')
return
def forward(
self,
inputs_embeds,
position_ids,
cache_position,
) -> torch.Tensor:
def ifa_update_sync(param):
with torch.npu.stream(self.update_stream):
for i in range(len(self.handle)):
if self.first_update is False:
q_nope, kvCache, q_pe, kRopeCache, num_heads, \
softmax_scale, layer_idx, attn_output, softmax_lse = self.ifa_param[i]
torch.npu.graph_task_update_begin(self.update_stream, self.handle[i])
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
kvCache,
kvCache,
workspace=self.workspace,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.past_key_value.page_table_list[layer_idx],
block_size=self.past_key_value.page_size,
actual_seq_lengths_kv=self.past_key_value.position,
out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(self.update_stream)
self.event[i].record(self.update_stream)
self.ifa_update_tid, ret = acl.util.start_thread(ifa_update_sync, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret1 = acl.rt.memcpy(self.input_buffers["inputs_embeds"].data_ptr(), inputs_embeds.numel() * 2,
inputs_embeds.data_ptr(), inputs_embeds.numel() * 2, 3)
ret2 = acl.rt.memcpy(self.input_buffers["position_ids"].data_ptr(), position_ids.numel() * 8,
position_ids.data_ptr(), position_ids.numel() * 8, 3)
ret3 = acl.rt.memcpy(self.input_buffers["cache_position"].data_ptr(), cache_position.numel() * 8,
cache_position.data_ptr(), cache_position.numel() * 8, 3)
torch_npu.npu.synchronize()
with torch_npu.npu.stream(self.main_stream):
self.graph.replay()
self.first_update = False
ret = acl.util.stop_thread(self.ifa_update_tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
else:
self.ifa_update_tid = None
return self.output_buffers["logits"]
def launch_callback(self, func, data, block, stream):
self.tmp_g.launch_callback(func, data, block, stream)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
runner_dict = dict()
def check_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
return True
else:
return False
def destroy_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is not None:
runner_dict[deviceId] = None
def get_or_create_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
runner = NPUGraphRunner(deviceId)
runner_dict[deviceId] = runner
return runner

View file

@ -31,8 +31,35 @@ if not torch.xpu.is_available():
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
import os
import re
import torch.distributed as dist
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
warm_uped = False
W8A8_ENABLE = False
Q4_GGUF_LODER = None
USE_NPU_GRAPH = None
WARM_UP_SKIP_CNT = [1, 1]
_USE_NPU_GRAPH = False
_MAX_DECODE_PROFILE = 3
CUR_DEVICE = None
_MAX_CHUNK_SIZE = int(max(os.getenv("_MAX_CHUNK_SIZE", 4096), 512))
def get_use_npu_graph():
assert _USE_NPU_GRAPH is not None, "use npu graph is not setting"
return _USE_NPU_GRAPH
def get_free_ports(n: int, continue_prot: list):
sockets = []
ports = []
@ -50,6 +77,10 @@ def get_free_ports(n: int, continue_prot: list):
return ports
def get_compute_capability(device:torch.device = None):
if use_torch_npu:
return 0
if torch.cuda.is_available():
if device is None:
num_gpus = torch.cuda.device_count()
@ -97,9 +128,16 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
if "cpu" in all_device_list:
all_device_list.remove("cpu")
if use_torch_npu:
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
all_device_list = list(all_device_list)
return all_device_list
# TODO: support NPU
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
@ -109,6 +147,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
key = prefix + name
translated_key = key
# TODO: Merge all loader.
# I know this is ugly but lets do it for now.
if isinstance(gguf_loader, SafeTensorLoader):
@ -120,7 +159,13 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
if use_torch_npu:
device = "cpu" if "embd" in translated_key else CUR_DEVICE
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
else:
print(f"loading {translated_key} to {device}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
@ -149,6 +194,8 @@ def sync_all_device(all_device_list):
torch.cuda.synchronize(device)
elif "xpu" in device.lower():
torch.xpu.synchronize(device)
elif use_torch_npu:
torch_npu.synchronize(device)
else:
raise RuntimeError("The device {} is not available".format(device))
@ -228,20 +275,68 @@ def tf_logits_warper(generation_config):
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
if use_torch_npu:
vocabulary_size = model.config.vocab_size
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
torch_device = CUR_DEVICE
else:
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
tokens = []
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if cuda_graph_runner is None:
use_cuda_graph = False
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
if use_cuda_graph:
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
# torch.cuda.set_device(torch_device)
torch_npu.npu.set_device(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
if past_key_values != None:
past_key_values.change_seq_length(1)
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
for device in all_cuda_device:
# torch.cuda.synchronize(device)
torch_npu.npu.synchronize(device)
if generation_config.do_sample:
logits = logits / temperature
torch.manual_seed(0)
probs = logits.view(batch_size, vocabulary_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = next_token_fake
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
next_token = next_token.squeeze(-1)
else:
next_token_scores = logits_warper(inputs, logits[:, -1, :])
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if use_torch_npu:
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
if cuda_graph_runner is None:
use_cuda_graph = False
if use_cuda_graph:
@ -252,6 +347,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
@ -279,6 +376,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all()
@ -288,11 +386,88 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
return logits
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
global warm_uped
global _USE_NPU_GRAPH
if use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
npu_graph_runner.init(batch_size, seq_length)
with torch_npu.npu.stream(npu_graph_runner.main_stream):
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
warm_uped = True
_USE_NPU_GRAPH = True
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
cuda_graph_runner = npu_graph_runner
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
cache_position, past_key_values, logits_warper, generation_config,
use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
npu_graph_runner.destroy()
_USE_NPU_GRAPH = False
else:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
with torch.no_grad():
@ -304,6 +479,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
else:
past_key_values = DynamicNormalCache.from_legacy_cache(None)
elif use_torch_npu and static_cache:
assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'
past_key_values = static_cache
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
print('[WARN] current staticCache size exceeded, try create new staticCache...')
past_key_values = StaticCache(
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
)
else:
past_key_values.reset()
elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
@ -320,19 +505,67 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
if use_torch_npu:
past_key_values.position[0] = seq_length + 1
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
start_time = time.time()
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
logits = None
def prefill_wrapper(prof=None):
nonlocal logits
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
if use_torch_npu:
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
WARM_UP_SKIP_CNT[0] -= 1
else:
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
@ -348,56 +581,106 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
prefill_count = seq_length
prefill_time = first_token_time
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
elif not use_torch_npu:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token
tokens.append(int(next_token))
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)
seq_length += 1
if use_torch_npu:
past_key_values.position += 1
cuda_graph_runner = None
start_time = time.time()
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
if not use_torch_npu:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
else:
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
WARM_UP_SKIP_CNT[1] -= 1
total_time = time.time() - start_time
tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time
print("")
if not use_torch_npu:
print("")
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
else:
tp_size = get_tensor_parallel_size()
if torch.distributed.get_rank() % tp_size == 0:
rank = f"[rank:{torch.distributed.get_rank()}]"
msg = f"\n{rank} Eval Time\n"
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
msg += rank + f"eval count: {tokens_generated} token(s)\n"
msg += rank + f"eval duration: {total_time:.9f}s\n"
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
print(msg)
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
return tokens

View file

@ -12,6 +12,8 @@ from safetensors.torch import save_file
import re
from collections import defaultdict
SKIP_MTP = True
def read_safetensor_keys_from_folder(folder_path)->dict:
"""
:param folder_path: folder path
@ -36,7 +38,7 @@ def read_safetensor_keys_from_folder(folder_path)->dict:
try:
with safe_open(file_path, framework="pt") as f:
for key in f.keys():
if "model.layers.61" in key:
if SKIP_MTP and "model.layers.61" in key:
# skip MTP layer
continue
# try:
@ -94,6 +96,28 @@ def combine_tensor_sources(safetensor_path:str, gguf_path:str):
return target_tensor_map, gguf_loader
def combine_w8a8_tensor_sources(safetensor_path: str, gguf_path: str):
gguf_loader = GGUFLoader(gguf_path)
gguf_tensor_file_map = gguf_loader.tensor_file_map
safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
# build a map for the key to the tensor
# according to the key, we can get the tensor from the file
target_tensor_map = {}
for key in safetensor_tensor_file_map.keys():
# for all experts, we use the gguf tensor
if ".mlp.experts." in key and "weight_scale" not in key and "weight_offset" not in key:
key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
translated_key = translate_name(key)
target_tensor_map[key] = gguf_tensor_file_map[translated_key]
elif ".mlp.experts." in key and ("weight_scale" not in key or "weight_offset" not in key):
continue
else:
target_tensor_map[key] = safetensor_tensor_file_map[key]
return target_tensor_map, gguf_loader
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
# Ensure output directory exists
os.makedirs(output_path, exist_ok=True)
@ -193,6 +217,7 @@ def main():
parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
parser.add_argument("--safetensors_format", type=str, help="Safetensors format", default="fp8")
# print all the arguments
print("All the arguments:")
@ -204,8 +229,18 @@ def main():
safetensor_path = args.safetensor_path
gguf_path = args.gguf_path
output_path = args.output_path
safetensors_format = args.safetensors_format
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
match safetensors_format:
case "w8a8":
global SKIP_MTP
SKIP_MTP = False
target_tensor_map, gguf_loader = combine_w8a8_tensor_sources(safetensor_path, gguf_path)
case "fp8":
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
case _:
raise ValueError(f"Unsupported safetensors format: {safetensor_path}")
write_combined_tensor(target_tensor_map, output_path, gguf_loader)
return

View file

@ -673,10 +673,29 @@ if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
elif torch.xpu.is_available():
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
]
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
elif KTRANSFORMERS_BUILD_NPU:
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
@ -687,11 +706,10 @@ elif KTRANSFORMERS_BUILD_NPU:
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)

File diff suppressed because it is too large Load diff

5866
third_party/llamafile/iqk_mul_mat_arm.inc vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,10 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm80.cpp
// Copyright 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __aarch64__
#define iqk_mul_mat iqk_mul_mat_arm80
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm80
#include "iqk_mul_mat.inc"
#endif // __aarch64__

4925
third_party/llamafile/iqk_mul_mat_x86.inc vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -1,204 +1,7 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// 使用 x86 版本
#include "sgemm_arm.cpp"
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
long hwcap = getauxval(AT_HWCAP);
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// e.g. Apple M1, Raspberry Pi 5
sgemm = llamafile_sgemm_arm82;
mixmul = llamafile_mixmul_arm82;
iqk_mixmul = iqk_mul_mat_moe_arm82;
} else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
}
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}
// 使用 ARM 版本
#include "sgemm_x86.cpp"
#endif

204
third_party/llamafile/sgemm_arm.cpp vendored Normal file
View file

@ -0,0 +1,204 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
// long hwcap = getauxval(AT_HWCAP);
// if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
// (hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
// (hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// // e.g. Apple M1, Raspberry Pi 5
// sgemm = llamafile_sgemm_arm82;
// mixmul = llamafile_mixmul_arm82;
// iqk_mixmul = iqk_mul_mat_moe_arm82;
// } else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
// }
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}

204
third_party/llamafile/sgemm_x86.cpp vendored Normal file
View file

@ -0,0 +1,204 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
long hwcap = getauxval(AT_HWCAP);
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// e.g. Apple M1, Raspberry Pi 5
sgemm = llamafile_sgemm_arm82;
mixmul = llamafile_mixmul_arm82;
iqk_mixmul = iqk_mul_mat_moe_arm82;
} else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
}
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}

View file

@ -5,6 +5,7 @@
#ifdef __aarch64__
#define llamafile_mixmul llamafile_mixmul_arm80
#define iqk_mul_mat iqk_mul_mat_arm80
#include "tinyblas_cpu_mixmul.inc"
/**

View file

@ -1,361 +1,7 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
namespace {
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// 使用 x86 版本
#include "tinyblas_cpu_sgemm_arm.inc"
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
moonll
more Btype accept
}*/
if (Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}
// 使用 ARM 版本
#include "tinyblas_cpu_sgemm_x86.inc"
#endif

View file

@ -0,0 +1,471 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
#include <arm_neon.h>
#include <ostream>
#include <iostream>
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
namespace {
template <typename TC>
void SgemmHelperN1Neon2(long m, long n, long k, const float16_t* A, long lda, const float16_t* B, long ldb,
TC* C, long ldc, int ith, int nth) {
// A m * k B n * k c n * m
const long NVL = 8;
long kk = k / (NVL * 4);
kk = kk * (NVL * 4);
long length = (m / nth) + (ith < (m % nth) ? 1 : 0);
long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth));
long endRow = startRow + length;
for (long i = startRow; i < endRow; i ++) {
const float16_t* tA = A + i * lda;
float32x4_t c0 = vdupq_n_f32(0);
float32x4_t c1 = vdupq_n_f32(0);
float32x4_t c2 = vdupq_n_f32(0);
float32x4_t c3 = vdupq_n_f32(0);
float32x4_t c4 = vdupq_n_f32(0);
float32x4_t c5 = vdupq_n_f32(0);
float32x4_t c6 = vdupq_n_f32(0);
float32x4_t c7 = vdupq_n_f32(0);
for (long j = 0; j < kk; j += NVL * 4) {
__builtin_prefetch(tA + 192, 0, 0);
float16x8_t a0 = vld1q_f16(tA + j);
float16x8_t b0 = vld1q_f16(B + j);
c0 = vfmlalq_low_f16(c0, a0, b0);
c1 = vfmlalq_high_f16(c1, a0, b0);
float16x8_t a1 = vld1q_f16(tA + j + NVL);
float16x8_t b1 = vld1q_f16(B + j + NVL);
c2 = vfmlalq_low_f16(c2, a1, b1);
c3 = vfmlalq_high_f16(c3, a1, b1);
float16x8_t a2 = vld1q_f16(tA + j + NVL * 2);
float16x8_t b2 = vld1q_f16(B + j + NVL * 2);
c4 = vfmlalq_low_f16(c4, a2, b2);
c5 = vfmlalq_high_f16(c5, a2, b2);
float16x8_t a3 = vld1q_f16(tA + j + NVL * 3);
float16x8_t b3 = vld1q_f16(B + j + NVL * 3);
c6 = vfmlalq_low_f16(c6, a3, b3);
c7 = vfmlalq_high_f16(c7, a3, b3);
}
if (k - kk >= NVL * 2) {
float16x8_t a0 = vld1q_f16(tA + kk);
float16x8_t b0 = vld1q_f16(B + kk);
c0 = vfmlalq_low_f16(c0, a0, b0);
c1 = vfmlalq_high_f16(c1, a0, b0);
float16x8_t a1 = vld1q_f16(tA + kk + NVL);
float16x8_t b1 = vld1q_f16(B + kk + NVL);
c2 = vfmlalq_low_f16(c2, a1, b1);
c3 = vfmlalq_high_f16(c3, a1, b1);
kk += NVL * 2;
}
if (k - kk >= NVL) {
float16x8_t a = vld1q_f16(tA + kk);
float16x8_t b = vld1q_f16(B + kk);
c0 = vfmlalq_low_f16(c0, a, b);
c1 = vfmlalq_high_f16(c1, a, b);
kk += NVL;
}
TC sum = 0.0f;
for (long j = kk; j < k; j ++) {
sum += (float32_t)tA[j] * (float32_t)B[j];
}
c0 = vaddq_f32(c0, c1);
c2 = vaddq_f32(c2, c3);
c4 = vaddq_f32(c4, c5);
c6 = vaddq_f32(c6, c7);
c0 = vaddq_f32(c0, c2);
c4 = vaddq_f32(c4, c6);
sum += vaddvq_f32(c0) + vaddvq_f32(c4);
C[i] = sum;
}
return;
}
template <typename TC>
void SgemmHelperN1(long m, long n, long k, const ggml_fp16_t* A_, long lda, const ggml_fp16_t* B_, long ldb,
TC* C, long ldc, int ith, int nth) {
// A m * k B n * k c n * m
float16_t *A = (float16_t*)A_;
float16_t *B = (float16_t*)B_;
long rowsPerThread = m / nth;
long startRow = ith * rowsPerThread;
long endRow = (ith == nth - 1) ? m : startRow + rowsPerThread;
for (long i = startRow; i < endRow; i ++) {
TC sum = 0.0f;
for (long j = 0; j < k; j ++) {
sum += (float32_t)A[i * lda + j] * (float32_t)B[j];
}
C[i] = sum;
}
return;
}
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise) {
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
return true;
}
return NOT_SUPPORTED;
}
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise) {
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
return true;
}
return NOT_SUPPORTED;
}
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
moonll
more Btype accept
}*/
// if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}

View file

@ -0,0 +1,361 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyrigth 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
namespace {
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
moonll
more Btype accept
}*/
if (Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}