support npu

djw 2025-07-21 12:26:14 +00:00
parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions

View file

@@ -44,6 +44,10 @@ option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM"
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
    add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
@@ -90,6 +94,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
    endif ()
    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
else()
    if(KTRANSFORMERS_USE_NPU)
        list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
    endif()
    check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
    if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
        list(APPEND ARCH_FLAGS -mfp16-format=ieee)
@@ -117,37 +124,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
        CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
    message(STATUS "x86 detected")
    if(NOT KTRANSFORMERS_USE_NPU)
        set(HOST_IS_X86 TRUE)
        set(HAS_AVX512 TRUE)
        set(__HAS_AMX__ TRUE)
        add_compile_definitions(__x86_64__)
        # check AVX512
        execute_process(
            COMMAND lscpu
            OUTPUT_VARIABLE LSCPU_OUTPUT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
        string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
        if (COMPILER_SUPPORTS_AVX512F GREATER -1)
            message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
            add_compile_definitions(__HAS_AVX512F__)
        else()
            message(STATUS "Compiler and/or CPU do NOT support AVX512F")
            set(HAS_AVX512 False)
        endif()
        # check AMX
        string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
        if(COMPILER_SUPPORTS_AMX GREATER -1)
            message(STATUS "Compiler supports AMX")
            add_compile_definitions(__HAS_AMX__)
        else()
            message(STATUS "Compiler does NOT support AMX")
        endif()
    if (MSVC)
        # instruction set detection for MSVC only
        if (LLAMA_NATIVE)
@@ -281,6 +289,8 @@ if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
    if (KTRANSFORMERS_USE_ROCM)
        find_package(HIP REQUIRED)
        if(HIP_FOUND)
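For reference, the new toggle is consumed at configure time: something like `cmake -B build -DKTRANSFORMERS_USE_NPU=ON` (build directory name illustrative, not part of this commit) defines KTRANSFORMERS_USE_NPU=1 for the C++ sources, appends the armv8.2-a fp16/dotprod flags on ARM hosts, and skips the x86 AVX512/AMX probing shown above.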

View file

@@ -0,0 +1,257 @@
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os
import platform
import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import torch.distributed as dist
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils
from ktransformers.models.custom_cache import StaticCache
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
torch.npu.config.allow_internal_format = True
ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "npu/DeepSeek-V3-Chat.yaml",
}
torch.npu.set_compile_mode(jit_compile=False)
import sys, signal, faulthandler
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)
def local_chat(
model_path: str | None = None,
optimize_config_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = False,
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_size: int = utils._MAX_CHUNK_SIZE,
q4_gguf_path: str | None = None,
tp: int = 1,
):
utils.USE_NPU_GRAPH = use_cuda_graph
torch.npu.config.allow_internal_format = False
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
local_rank, world_size = setup_model_parallel(tp=tp)
if utils.CUR_DEVICE is None:
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if use_cuda_graph:
from ktransformers.util import npu_graph_runner
npu_graph_runner.LAYER_ID = config.num_hidden_layers
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_config_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0]) if local_rank == 0 else None
optimize_config_path = default_optimize_rules[config.architectures[0]]
print(f'{optimize_config_path=}') if local_rank == 0 else None
else:
optimize_config_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, q4_gguf_path=q4_gguf_path)
get_absort_weight(model, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
print(f"generation config can't auto create, make default. Message: {e}")
gen_config = GenerationConfig(
temperature=0.6,
top_p=0.95,
do_sample=True
)
model.generation_config = gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system()
if system == "Windows":
os.system("cls") if local_rank == 0 else None
else:
os.system("clear") if local_rank == 0 else None
print(f"{model=}") if local_rank == 0 else None
batch_size, seq_length = 1, 1024
device_map = model.gguf_loader.tensor_device_map
static_cache = StaticCache(
config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map,
dtype = model.dtype
)
chunk_size = int(chunk_size)
new_chunk_size = min(max(chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != chunk_size:
chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {chunk_size}.')
torch.distributed.barrier()
while True:
if local_rank == 0:
try:
content = input("Chat: ").strip()
except KeyboardInterrupt:
dist.barrier()
print('Exit all ranks with KeyboardInterrupt!')
sys.exit(0)
if content.startswith('"""'): # prefix """
# multi lines input
content = content[3:] + "\n"
while True:
line = input("")
if line.endswith('"""'):
# end multi lines input
line = line[:-3] # suffix """
if line:
content += line + "\n"
break
else:
content += line + "\n"
if content == "":
if prompt_file != None:
content = open(prompt_file, "r").read()
else:
continue
elif os.path.isfile(content):
f = open(content, "r")
content = f.readlines()
f.close()
else:
content = [f"{len(content)},{max_new_tokens},{content}"]
else:
content = [""]
for line in content:
content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
if world_size > 1:
content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE)
all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_sizes, content_size)
max_content_size = max([size.item() for size in all_content_sizes])
padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
padded_content_tensor[:len(content_tensor)] = content_tensor
all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
dist.barrier()
dist.all_gather(all_content_tensors, padded_content_tensor)
content_tensor = all_content_tensors[0][:all_content_sizes[0].item()]
line = bytes(content_tensor.cpu().numpy()).decode()
parts = line.split(",")
input_tokens = int(parts[0])
max_new_tokens = int(parts[1])
line = line[line.index(",", line.index(",") + 1) + 1:]
messages = [{"role": "user", "content": line}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
)
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim,
static_cache=static_cache
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
static_cache=static_cache
)
if __name__ == "__main__":
fire.Fire(local_chat)
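The multi-rank branch above shares rank 0's prompt by padding each rank's byte tensor to a common length and all_gathering it. A minimal standalone sketch of that pattern (helper name is illustrative; it assumes torch.distributed is already initialized):

    import torch
    import torch.distributed as dist

    def broadcast_text_from_rank0(text: str, device: str) -> str:
        """Share rank 0's prompt with every rank via padded all_gather."""
        data = torch.tensor(bytearray(text.encode()), dtype=torch.uint8, device=device)
        size = torch.tensor([data.numel()], dtype=torch.int64, device=device)
        world = dist.get_world_size()
        sizes = [torch.zeros_like(size) for _ in range(world)]
        dist.all_gather(sizes, size)
        max_len = int(max(s.item() for s in sizes))
        padded = torch.zeros(max_len, dtype=torch.uint8, device=device)
        padded[: data.numel()] = data
        gathered = [torch.zeros_like(padded) for _ in range(world)]
        dist.all_gather(gathered, padded)
        # every rank decodes rank 0's (unpadded) bytes
        return bytes(gathered[0][: sizes[0].item()].cpu().numpy()).decode()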

View file

@@ -16,6 +16,16 @@ try:
    from ktransformers.server.balance_serve.settings import sched_ext
except:
    print("no balance_serve")

try:
    import torch_npu
    from ktransformers.util import utils
    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

class StaticCache(transformers.StaticCache):
    """
    Static Cache class to be used with `torch.compile(model)`.
@@ -37,6 +47,10 @@ class StaticCache(transformers.StaticCache):
    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None:
        Cache.__init__(self)
        self.max_batch_size = max_batch_size
        if use_torch_npu:
            self.position = [0]
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        if config.architectures[0] == "DeepseekV3ForCausalLM":
@@ -56,8 +70,18 @@ class StaticCache(transformers.StaticCache):
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
            if use_torch_npu:
                self.page_size = 128
                self.page_size_tensor = torch.tensor(
                    self.page_size,
                    dtype=torch.int32,
                ).npu()
                self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
            else:
                self.page_size = 64
                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            self.kv_lora_rank = config.kv_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim
@@ -71,9 +95,14 @@
            target_device = device

            if target_device not in self.page_table_map:
                if use_torch_npu:
                    page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
                    for seq_id in range(max_batch_size):
                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
                else:
                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                    for seq_id in range(max_batch_size):
                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                self.page_table_map[target_device] = page_table

            self.page_table_list.append(self.page_table_map[target_device])
@@ -140,11 +169,24 @@
        self.past_tokens[layer_idx] += cache_position.size(0)
        #print(cache_position)
        if self.is_MLA:
            if use_torch_npu:
                page_idx = cache_position // self.page_size_tensor
                page_offset = cache_position % self.page_size_tensor

                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)
                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
                page_idx = page_idx + page_idx_offset.unsqueeze(1)

                combined = torch.cat([key_states, value_states], dim=-1)
                combined = combined.contiguous()
            else:
                page_idx = cache_position // self.page_size
                page_offset = cache_position % self.page_size
                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
            return k_out, self.page_table_list[layer_idx]
        else:
            k_out[:, :, cache_position] = key_states
@@ -178,6 +220,9 @@
            if self.value_cache[layer_idx] is not None:
                self.value_cache[layer_idx].zero_()
            self.past_tokens[layer_idx] = 0
        if use_torch_npu:
            self.position = [0]

    def remove_suffix(self, start_pos):
        for layer_idx in range(len(self.key_cache)):
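A small sketch (not part of the commit) of the paged-cache indexing implied by the fields above, where each batch element owns a contiguous run of max_pages_per_batch pages; the page_size and max_pages_per_batch values below are example numbers:

    def paged_slot(position: int, batch: int, page_size: int = 128, max_pages_per_batch: int = 32):
        """Map an absolute token position of one batch element to (page_idx, page_offset)."""
        page_idx = position // page_size + batch * max_pages_per_batch
        page_offset = position % page_size
        return page_idx, page_offset

    # e.g. paged_slot(300, batch=1) -> (34, 44): token 300 sits in page 2 of batch 1,
    # whose pages start at index 32 because batch 0 owns pages 0..31.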

View file

@@ -27,8 +27,12 @@ try:
    from flash_attn import flash_attn_func
except:
    pass
try:
    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
    from ktransformers.operators.triton_attention_prefill import context_attention_fwd
except:
    Warning("triton not found, if you are using npu, ignore this.")
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:

View file

@@ -1,5 +1,8 @@
import torch
try:
    import flashinfer
except:
    Warning("flashinfer not found, if you are using npu, ignore this.")
import gc
try:
    from flash_attn import flash_attn_with_kvcache

View file

@@ -5,7 +5,11 @@ Version : 0.2.3
'''
import torch
import os
try:
    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
except:
    Warning("triton not found, if you are using npu, ignore this.")
flashinfer_enabled = False

View file

@@ -14,7 +14,15 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn

try:
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

if not torch.xpu.is_available() and not use_torch_npu:
    import KTransformersOps
    import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader

View file

@@ -16,6 +16,7 @@ from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
from ktransformers.util import utils

def inject(module, local_optimization_dict, model_config: AutoConfig, gguf_loader: GGUFLoader, prefix=''):
    for name, child in module._modules.items():
@@ -114,7 +115,7 @@ def translate_model_config(model_config: PretrainedConfig):
    return model_config

def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0", q4_gguf_path=""):
    with open(rule_file, 'r', encoding='utf-8') as f:
        rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
@@ -123,15 +124,29 @@
    model_config = translate_model_config(model_config)

    if q4_gguf_path:
        q4_gguf_loader = GGUFLoader(q4_gguf_path)
        utils.Q4_GGUF_LODER = q4_gguf_loader
        gguf_loader = GGUFLoader(gguf_path, getattr(model_config, "quantize", None))
        with torch.device("meta"):
            inject(module, optimize_config, model_config, gguf_loader)
        # pre load lm_head because its big inter result
        load_weights(module.lm_head, gguf_loader, "lm_head.")
        load_weights(module, gguf_loader)
        module.gguf_loader = gguf_loader
    else:
        weights_loader = ModelLoaderFactory.create_loader(gguf_path)
        with torch.device("meta"):
            inject(module, optimize_config, model_config, weights_loader)
        # pre load lm_head because its big inter result
        load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
        load_weights(module, weights_loader, device=default_device)
        module.gguf_loader = weights_loader
    del_meta(module)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.xpu.is_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()

View file

@@ -0,0 +1,76 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "npu:0"
      prefill_device: "npu:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "npu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "npu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
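The rule file above follows the usual ktransformers optimize-rule layout: each entry pairs a `match` (name regex and/or class) with a `replace` target plus per-device kwargs, and is consumed with plain yaml.load as in optimize.py. A tiny illustrative loader (path relative to the optimize_rules directory referenced earlier):

    import yaml

    with open("npu/DeepSeek-V3-Chat.yaml", "r", encoding="utf-8") as f:
        rules = yaml.load(f.read(), Loader=yaml.FullLoader)
    for rule in rules:
        print(rule.get("match"), "->", rule["replace"]["class"])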

View file

@@ -22,6 +22,10 @@ class ArgumentParser:
            "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
        )
        parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
        parser.add_argument("--tp", type=int, default=1)
        parser.add_argument("--q4_gguf_path", type=str, default=None)
        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)

View file

@@ -8,6 +8,7 @@ class ConfigArgs(BaseModel):
    model_dir: Optional[str] = Field(..., description="Path to model directory")
    optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
    gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
    tp: int = Field(None, description="tp size")

    class Config:
        protected_namespaces = ()

View file

@@ -1,4 +1,19 @@
import torch
from torch import nn

try:
    import torch_npu
    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
    from ktransformers.util.utils import get_device, get_all_used_cuda_device
    from ktransformers.util import utils
    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

import os
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
@@ -19,6 +34,9 @@ from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage

warm_uped = False

class KTransformersThreadContext(TransformersThreadContext):
@@ -26,8 +44,15 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
    def __init__(self, args: ConfigArgs = default_args, input_args=None):
        if use_torch_npu:
            self.args = input_args
            self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)
            if utils.CUR_DEVICE is None:
                utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
            self.args.device = utils.CUR_DEVICE
        else:
            self.args = args
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
@@ -47,7 +72,10 @@ class KTransformersInterface(TransformersInterface):
        with torch.device("meta"):
            self.model = custom_models[config.architectures[0]](config)

        if use_torch_npu and input_args.optimize_config_path is not None:
            optimize_config_path = input_args.optimize_config_path
        elif default_args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path
@@ -60,7 +88,14 @@
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )

        if use_torch_npu:
            optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)
            # absorb weights ahead of time
            get_absort_weight(self.model, config)
            self.model.eval()
        else:
            optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        self.model.generation_config = generation_config
        self.device_map = self.model.gguf_loader.tensor_device_map
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
@@ -77,9 +112,92 @@
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.streamer = TextStreamer(self.tokenizer)

        if use_torch_npu:
            self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype=torch.float16, device=self.args.device)
            self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype=torch.int32, device=self.args.device)
            self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype=torch.float16, device=self.args.device)
            self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device=self.args.device)
            self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device=self.args.device)
            self._infer_lock = asyncio.Lock()

        self._infer_lock = asyncio.Lock()
def decode_logits_to_token(self, logits: torch.Tensor):
if self.model.generation_config.do_sample:
logits = logits / self.temperature
torch.manual_seed(0)
probs = logits.view(1, self.model.config.vocab_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = self.next_token_fake
torch_npu._npu_topk_topp_sampling(probs, self.top_k, self.top_p, next_token, self.next_token_probs)
last = next_token.squeeze(-1)
else:
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
probs = torch.nn.functional.softmax(logits, dim=-1)
_, last = torch.topk(probs, k=1, dim=-1)
last = last.item()
self.ever_generated_ids.add(last)
return last
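For readers without the Ascend stack, a plain-PyTorch stand-in for the `_npu_topk_topp_sampling` call above (illustrative only; it takes scalar top_k/top_p rather than the 1x1 device tensors used in the commit) could look like:

    def topk_topp_sample(probs: torch.Tensor, top_k: int, top_p: float) -> int:
        # probs: (1, vocab_size), already softmaxed
        topk_probs, topk_idx = torch.topk(probs, k=top_k, dim=-1)   # keep the top_k candidates
        cumsum = torch.cumsum(topk_probs, dim=-1)
        nucleus = cumsum - topk_probs <= top_p                      # smallest prefix whose mass exceeds top_p
        kept = topk_probs * nucleus
        kept = kept / kept.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(kept.float(), num_samples=1)     # sample within the nucleus
        return int(topk_idx.gather(-1, choice))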
def decode_one_tokens_npu(self):
global warm_uped
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
torch.cuda.set_device(torch_device)
if warm_uped and self.args.use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner
if check_runner(self.args.device):
npu_graph_runner = get_or_create_runner(self.args.device)
npu_graph_runner.init(self.args.batch_size, self.seq_length)
self.cuda_graph_runner = npu_graph_runner
utils._USE_NPU_GRAPH = True
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=self.args.device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(self.args.device)
logits = self.cuda_graph_runner(
inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
if self.args.use_cuda_graph:
warm_uped = True
if self.use_static_cache:
logits = self.model(
self.current_ids.to(torch_device),
cache_position=self.active_cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(self.current_ids, return_dict=False)[0]
self.cache.change_seq_length(1)
logits = logits[0, -1, :]
return self.decode_logits_to_token(logits)
    def decode_one_tokens(self):
        if use_torch_npu:
            return self.decode_one_tokens_npu()

        global warm_uped
        device_map = self.model.gguf_loader.tensor_device_map
@@ -127,9 +245,145 @@
        return self.logits_to_token(logits)
@torch.no_grad
def prefill_npu(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
device = self.args.device
if is_new:
self.ever_generated_ids.clear()
same_prefix = 0
flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
self.seq_length = 1
# flat_prev_ids = self.generated_ids.flatten()
# for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
# if flat_input_ids[i] == flat_prev_ids[i]:
# same_prefix += 1
# else:
# break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.cache.position[0] = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
else:
logger.warning(f"seq_length bigger than cache_lens, killed")
exit(0)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.cache.position[0] = self.seq_length + 1
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
def chunk_prefill(input_ids, cache_position):
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
return logits
logits = None
def prefill_wrapper(prof=None):
nonlocal logits
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx = cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1":
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
        if use_torch_npu:
            return self.prefill_npu(self, input_ids, is_new, temperature, top_p, max_tokens, max_completion_tokens)

        input_ids_length = input_ids.shape[-1]
        if max_tokens is not None:
            max_completion_tokens = max_tokens
@@ -144,6 +398,8 @@ class KTransformersInterface(TransformersInterface):
        logger.debug(f"input_ids: {input_ids.shape}")
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        device = "cuda:0" if device == "cuda" else device
        if use_torch_npu:
            device = self.args.device

        if is_new:
            self.ever_generated_ids.clear()
@@ -159,16 +415,19 @@
            )
            self.seq_length = 1

        if not use_torch_npu:
            flat_prev_ids = self.generated_ids.flatten()
            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
                if flat_input_ids[i] == flat_prev_ids[i]:
                    same_prefix += 1
                else:
                    break

        logger.debug(f"same prefix len: {same_prefix}")
        self.cache.remove_suffix(same_prefix)
        self.seq_length = same_prefix
        if use_torch_npu:
            self.cache.position[0] = same_prefix
        self.generated_ids = self.generated_ids[..., :same_prefix]
        input_ids = input_ids[..., same_prefix:]
        input_ids_length = input_ids.shape[-1]
@@ -193,6 +452,8 @@
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
        if use_torch_npu:
            self.cache.position[0] = self.seq_length + 1
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        if not (type(self) is TransformersInterface):
@@ -248,4 +509,18 @@
            decode_time = self.profiler.get_timer_sec('decode'),
            prefill_count = self.profiler.get_counter('prefill'),
            decode_count = self.profiler.get_counter('decode'),
        )
def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def run_async():
result = []
async for chunk in self.inference(local_messages, thread_id, temperature, top_p):
pass
return ""
return loop.run_until_complete(run_async())
finally:
loop.close()
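sync_inference above is what the non-zero ranks in main_npu (later in this diff) call in a loop with an empty prompt, so they can keep stepping through inference alongside rank 0 while it alone serves HTTP; it simply drains the async inference generator from blocking code. The same pattern in isolation (helper name illustrative):

    import asyncio

    def drain_async_generator(agen) -> str:
        """Run an async generator to exhaustion from synchronous code."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            async def _run():
                async for _ in agen:
                    pass
                return ""
            return loop.run_until_complete(_run())
        finally:
            loop.close()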

View file

@@ -32,6 +32,20 @@ from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

try:
    import torch_npu
    from ktransformers.util import utils
    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

import torch.distributed as dist

# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:
@@ -191,11 +205,19 @@
        # input_ids = self.tokenizer.apply_chat_template(
        #     new_messages, return_tensors="pt", add_generation_prompt=True
        # ).to(self.args.device)

        if not use_torch_npu:
            input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
            # drop <think> token in chat template
            if input_str.endswith('<think>\n'):
                input_str = input_str[:-len('<think>\n')]
            input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        else:
            logger.debug(f"new_messages: {new_messages}")
            input_ids = self.tokenizer.apply_chat_template(
                new_messages, add_generation_prompt=True, return_tensors="pt"
            )

        if (self.last_request_id is not None) and self.last_request_id == thread_id:
            x = self.generated_ids[:,:self.seq_length]
            y = input_ids[:,:self.seq_length]
@@ -212,6 +234,8 @@
    def append_new_tokens(self, new_tokens: int) -> Optional[str]:
        self.generated_ids[0, self.seq_length] = new_tokens
        self.seq_length += 1
        if use_torch_npu:
            self.cache.position[0] = self.seq_length
        return self.streamer.put(new_tokens)

    @staticmethod
@@ -273,14 +297,21 @@
        top_p = self.model.generation_config.top_p
        if top_p == 0:
            top_p = 0.0001

        if use_torch_npu:
            generation_config, model_kwargs = self.model._prepare_generation_config(
                None, do_sample=True,
                top_p=top_p, temperature=temperature
            )
        else:
            generation_config, model_kwargs = self.model._prepare_generation_config(
                None, max_length=self.args.max_new_tokens,
                do_sample=True,
                top_k=self.args.top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=self.args.repetition_penalty # change this to modify generate config
            )

        self.inputs = inputs
        self.logits_warper = self.tf_logits_warper(generation_config)
@@ -372,7 +403,10 @@
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        if use_torch_npu:
            device = self.args.device
        else:
            device = input_ids.device
        if not (type(self) is TransformersInterface):
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
@@ -420,7 +454,12 @@
        else:  # for's else, if output get max new tokens
            yield self.streamer.end(), None
            yield "", "length"

        if use_torch_npu and self.args.use_cuda_graph:
            utils._USE_NPU_GRAPH = False
            from ktransformers.util.npu_graph_runner import get_or_create_runner
            npu_graph_runner = get_or_create_runner(self.args.device)
            npu_graph_runner.destroy()

    def check_is_new(self, thread_id: str):
@@ -436,7 +475,87 @@
            self.last_request_id = thread_id
            return True
async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
rank = torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
if tp_size == world_size and tp_size > 1:
torch.distributed.barrier()
input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
dist.all_gather(all_input_sizes, input_size)
max_input_size = max([size.item() for size in all_input_sizes])
padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
dist.all_gather(all_padded_inputs, padded_input_ids)
original_size = all_input_sizes[0].item()
input_ids = all_padded_inputs[0][:, :original_size]
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
if tp_size == world_size and rank != 0:
pass
else:
print(think, end="",flush=True)
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t, finish_reason in self.generate():
if t is not None:
if tp_size == world_size and rank != 0:
pass
else:
print(t, end="",flush=True)
yield t, finish_reason
if tp_size == world_size and rank != 0:
pass
else:
self.profiler.pause_timer("decode")
self.report_last_time_performance()
    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
        if use_torch_npu:
            async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
                yield tok
            return

        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")
        if isinstance(local_messages, List):

View file

@@ -9,7 +9,7 @@ project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -17,6 +17,21 @@ from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger

import asyncio
from uuid import uuid4
import torch.distributed
import subprocess
import tempfile
import atexit

try:
    import torch_npu
    from ktransformers.util import utils
    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

def mount_app_routes(mount_app: FastAPI):
    sql_util = SQLUtil()
@@ -100,6 +115,77 @@ def custom_openapi(app):
    return app.openapi_schema
def main_npu():
torch.npu.config.allow_internal_format = False
cfg = Config()
arg_parser = ArgumentParser(cfg)
args = arg_parser.parse_args()
utils.USE_NPU_GRAPH = args.use_cuda_graph
new_chunk_size = min(max(args.chunk_size, 512), utils._MAX_CHUNK_SIZE)
if new_chunk_size != args.chunk_size:
args.chunk_size = new_chunk_size
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {args.chunk_size}.')
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg, input_args=args)
args.port += torch.distributed.get_rank()
tp_size = utils.get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if tp_size == world_size and tp_size > 1:
if torch.distributed.get_rank() == 0:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
else:
while True:
try:
context = get_thread_context_manager()
id = str(uuid4())
context.interface.sync_inference("", id)
except Exception as e:
print(f"An error occurred: {e}")
finally:
pass
else:
app = create_app()
custom_openapi(app)
run_api(
app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
def main(): def main():
cfg = Config() cfg = Config()
@ -119,4 +205,7 @@ def main():
) )
if __name__ == "__main__": if __name__ == "__main__":
main() if use_torch_npu:
main_npu()
else:
main()
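A minimal sketch of the rank-split pattern main_npu() uses when tensor parallelism spans the whole world: rank 0 serves the HTTP API while every other rank blocks in sync_inference() so it can join the collectives that rank 0 drives. The helper names here (serve_or_follow, run_api_fn) are illustrative, not part of the codebase.

import torch.distributed as dist
from uuid import uuid4

def serve_or_follow(run_api_fn, interface):
    # Sketch only: run_api_fn wraps create_app()/run_api(), interface is GlobalInterface.interface.
    if dist.get_rank() == 0:
        run_api_fn()                                        # rank 0 owns the API port
    else:
        while True:                                         # follower ranks never return
            try:
                interface.sync_inference("", str(uuid4()))  # participate in rank 0's collectives
            except Exception as e:
                print(f"[rank {dist.get_rank()}] error: {e}")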


@ -16,7 +16,7 @@ from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs): def create_interface(config: Config, default_args: ConfigArgs, input_args=None):
if config.backend_type=='transformers': if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
elif config.backend_type == 'exllamav2': elif config.backend_type == 'exllamav2':
@ -27,7 +27,12 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else: else:
raise NotImplementedError(f'{config.backend_type} not implemented') raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
if config.backend_type == 'ktransformers':
GlobalInterface.interface = BackendInterface(default_args, input_args)
else:
GlobalInterface.interface = BackendInterface(default_args)
GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface) GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)
class GlobalContextManager: class GlobalContextManager:


@ -0,0 +1,210 @@
import os
from datetime import timedelta
import torch
try:
import torch_npu
except ImportError:
    print("[WARN] torch_npu not found, please install torch_npu for NPU support.")
import torch.distributed as dist
_DATA_PARALLEL_SIZE = 0
_TENSOR_PARALLEL_SIZE = 0
_DATA_PARALLEL_GROUP = None
_TENSOR_PARALLEL_RANKS = None
_TENSOR_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
_DATA_PARALLEL_RANKS = None
_GLOBAL_GROUP = None
_LM_HEAD_GROUP = None
def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):
global _DATA_PARALLEL_SIZE
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_RANKS
global _TENSOR_PARALLEL_SIZE
global _TENSOR_PARALLEL_RANKS
global _TENSOR_PARALLEL_GROUP
os.environ["MASTER_ADDR"] = "localhost"
local_rank = int(os.getenv("LOCAL_RANK", '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
torch_npu.npu.set_device(local_rank)
tp_size = tp
dp_size = world_size // tp_size
_DATA_PARALLEL_SIZE = dp_size
_TENSOR_PARALLEL_SIZE = tp_size
torch.set_num_threads(8)
timeout = timedelta(minutes=distributed_timeout_minutes)
print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}")
torch.distributed.init_process_group(
backend='hccl',
world_size=world_size, rank=local_rank
)
print(f"init process group success ------rank is {local_rank}, world_size is {world_size}")
rank = torch.distributed.get_rank()
nccl_comm_cfgs = {}
for dp_group_id in range(tp_size):
ranks = list(range(dp_group_id, world_size, tp_size))
dp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
)
if rank in ranks:
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = dp_group
_DATA_PARALLEL_RANKS = ranks
for tp_group_id in range(dp_size):
start_rank = tp_group_id * tp_size
end_rank = (tp_group_id + 1) * tp_size
ranks = list(range(start_rank, end_rank))
tp_group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
)
if rank in ranks:
global _TENSOR_PARALLEL_GROUP
_TENSOR_PARALLEL_GROUP = tp_group
_TENSOR_PARALLEL_RANKS = ranks
torch.manual_seed(1)
return local_rank, world_size
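For concreteness, the loops above put consecutive ranks into the same tensor-parallel group and strided ranks into the same data-parallel group. A standalone sketch of just that mapping (no process groups are created here):

def parallel_groups(world_size: int, tp_size: int):
    # Mirrors the rank arithmetic in setup_model_parallel().
    dp_size = world_size // tp_size
    dp_groups = [list(range(g, world_size, tp_size)) for g in range(tp_size)]
    tp_groups = [list(range(g * tp_size, (g + 1) * tp_size)) for g in range(dp_size)]
    return dp_groups, tp_groups

# world_size=8, tp=4: tp groups are [0,1,2,3] and [4,5,6,7]; dp groups are [0,4], [1,5], [2,6], [3,7]
print(parallel_groups(8, 4))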
def get_tensor_parallel_size():
assert _TENSOR_PARALLEL_SIZE is not None, "tensor parallel size is not set"
return _TENSOR_PARALLEL_SIZE
def get_tensor_parallel_group():
assert _TENSOR_PARALLEL_GROUP is not None, "tensor parallel group is not initialized"
return _TENSOR_PARALLEL_GROUP
def get_tensor_parallel_ranks():
assert _TENSOR_PARALLEL_RANKS is not None, "tensor parallel ranks is not initialized"
return _TENSOR_PARALLEL_RANKS
def get_data_parallel_size():
assert _DATA_PARALLEL_SIZE is not None, "data parallel size is not initialized"
return _DATA_PARALLEL_SIZE
def get_data_parallel_gloo():
assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel gloo group is not initialized"
return _DATA_PARALLEL_GROUP_GLOO
def get_data_parallel_group():
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_data_parallel_ranks():
assert _DATA_PARALLEL_RANKS is not None, "data parallel ranks is not initialized"
return _DATA_PARALLEL_RANKS
def get_global_group():
assert _GLOBAL_GROUP is not None, "global group is not initialized"
return _GLOBAL_GROUP
def get_nccl_options(pg_name, nccl_comm_cfgs):
if pg_name in nccl_comm_cfgs:
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
return nccl_options
else:
return None
def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
translate_col_cut_tensors = ["ffn_down", "attn_output"]
translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"]
translate_lm_cut_tensor = ["output"]
tp = get_tensor_parallel_size()
if tp == 1 or weights.shape == torch.Size([1]):
return weights
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
rank %= tp
assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}"
if any(t in name for t in translate_col_cut_tensors):
if weights.dim() == 1:
return weights
dim = weights.shape[-1]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[:, rank * chunk_size: (rank + 1) * chunk_size]
return output_weights
elif any(t in name for t in translate_row_cut_tensors):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
elif (tp > 1) and (any(t in name for t in translate_lm_cut_tensor)):
dim = weights.shape[0]
assert dim % tp == 0, f"unexpected division {dim=} {world_size=}"
chunk_size = dim // tp
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
return output_weights
else:
return weights
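The splitting rule can be exercised with plain tensors: names matching the column-cut list are sliced on the last dimension, names matching the row-cut (or lm-head) list on the first, and anything else is returned untouched. A hedged standalone example with random weights and made-up tensor names:

import torch

def shard(name: str, w: torch.Tensor, rank: int, tp: int) -> torch.Tensor:
    # Mirrors the slicing arithmetic of get_safetensors_cut_weight(), minus the process-group lookups.
    if any(t in name for t in ("ffn_down", "attn_output")) and w.dim() > 1:
        chunk = w.shape[-1] // tp
        return w[:, rank * chunk:(rank + 1) * chunk]         # column cut
    if any(t in name for t in ("ffn_gate", "ffn_up", "attn_q_b", "output")):
        chunk = w.shape[0] // tp
        return w[rank * chunk:(rank + 1) * chunk]            # row cut
    return w

w = torch.randn(8, 16)
print(shard("blk.0.ffn_down.weight", w, rank=1, tp=4).shape)  # torch.Size([8, 4])
print(shard("blk.0.ffn_up.weight", w, rank=1, tp=4).shape)    # torch.Size([2, 16])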
def get_absort_weight(model, config):
local_rank = torch.distributed.get_rank()
tp = get_tensor_parallel_size()
local_rank %= tp
tp_heads = config.num_attention_heads // tp
for i in range(config.num_hidden_layers):
self = model.model.layers[i].self_attn
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
out_absorb = out_absorb.transpose(1, 2).contiguous()
setattr(self, "q_absorb", q_absorb)
setattr(self, "out_absorb", out_absorb)
del self.orig_module.kv_b_proj
def allreduce_wrapper(func):
def wrapper(*args, **kwargs):
orig_output = func(*args, **kwargs)
if isinstance(orig_output, tuple):
if get_tensor_parallel_size() > 1:
org_dtype = orig_output[0].dtype
if org_dtype == torch.bfloat16:
    # reduce an fp16 copy and cast back; all-reducing a temporary would discard the result
    reduced = orig_output[0].to(dtype=torch.float16)
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
    bf_orig_output = reduced.to(dtype=org_dtype)
else:
    dist.all_reduce(orig_output[0], op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
    bf_orig_output = orig_output[0]
else:
bf_orig_output = orig_output[0]
return (bf_orig_output,) + orig_output[1:]
else:
if get_tensor_parallel_size() > 1:
org_dtype = orig_output.dtype
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=torch.float16)
dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
if org_dtype == torch.bfloat16:
orig_output = orig_output.to(dtype=org_dtype)
return orig_output
return wrapper
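Usage is the ordinary decorator pattern: wrap a forward() that produces a per-rank partial result so the tensor-parallel shards are summed before the caller sees the output. A hedged sketch; the module is hypothetical and it assumes the groups above are already initialized and that this file is importable as ktransformers.util.ascend.ascend_utils, as the later imports suggest.

import torch
from ktransformers.util.ascend.ascend_utils import allreduce_wrapper  # assumed module path

class RowParallelLinear(torch.nn.Module):
    # Hypothetical module: each rank multiplies by its weight shard, the wrapper sums the partials.
    def __init__(self, weight_shard: torch.Tensor):
        super().__init__()
        self.weight_shard = torch.nn.Parameter(weight_shard)

    @allreduce_wrapper
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight_shard.t()   # partial result; all-reduced across the TP group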


@ -7,10 +7,19 @@ from typing import Sequence
import os import os
from enum import IntEnum from enum import IntEnum
import torch import torch
if not torch.xpu.is_available():
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
if not torch.xpu.is_available() and not use_torch_npu:
import KTransformersOps import KTransformersOps
from safetensors import safe_open from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant if not use_torch_npu:
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from ktransformers.util.custom_gguf import * from ktransformers.util.custom_gguf import *
from safetensors.torch import save_file from safetensors.torch import save_file
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -42,6 +51,7 @@ class SafeTensorLoader(ModelLoader):
tensor_device_map: dict tensor_device_map: dict
def __init__(self, file_path: str): def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path) self.__load_tensor_file_map(file_path)
def __load_tensor_file_map(self, file_path: str): def __load_tensor_file_map(self, file_path: str):
@ -84,6 +94,7 @@ class SafeTensorLoader(ModelLoader):
# if not found_safetensor: # if not found_safetensor:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}") # raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"): def load_tensor(self, key: str, device: str="cpu"):
if translate_name_to_gguf(key) in self.tensor_file_map: if translate_name_to_gguf(key) in self.tensor_file_map:
key = translate_name_to_gguf(key) key = translate_name_to_gguf(key)
@ -96,6 +107,7 @@ class SafeTensorLoader(ModelLoader):
if f is None: if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files") raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key) tensor = f.get_tensor(key)
return tensor.to(device) return tensor.to(device)
def load_experts(self, key: str, device: str="cpu"): def load_experts(self, key: str, device: str="cpu"):
@ -252,20 +264,57 @@ class SafeTensorLoader(ModelLoader):
def has_tensor(self, name: str): def has_tensor(self, name: str):
return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
class W8A8SafeTensorLoader(SafeTensorLoader):
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if "deq_scale" in key:
tensor = torch.from_numpy(
np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64))
if "input_scale" in key:
tensor = tensor.to(torch.float16)
if "weight_scale" in key or "weight_offset" in key:
if "ffn" in key:
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(torch.float16)
if "input_offset" in key:
tensor = tensor.to(torch.int8)
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float16)
return tensor.to(device)
def load_dequantized_tensor(self, key: str, device: str = "cpu"):
tensor = self.load_tensor(key, device)
return tensor
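The deq_scale handling is a bit-level reinterpretation rather than a numeric cast: each fp32 scale's IEEE-754 bit pattern is viewed as an int32 and widened to int64, presumably the layout the Ascend quantized kernels expect. A tiny standalone illustration of the same trick:

import numpy as np
import torch

scale = torch.tensor([0.5, 1.0], dtype=torch.float32)
as_int = torch.from_numpy(
    np.frombuffer(scale.numpy().tobytes(), dtype=np.int32).astype(np.int64)
)
print(as_int)  # tensor([1056964608, 1065353216]) -- the bit patterns of 0.5 and 1.0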
class GGUFLoader(ModelLoader): class GGUFLoader(ModelLoader):
tensor_info: dict tensor_info: dict
gguf_path: str gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path} tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict gguf_file_meta: dict
safetensor_loader: SafeTensorLoader safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str): def __init__(self, gguf_path: str, quantize: str = None):
# Check dir exist # Check dir exist
if not os.path.exists(gguf_path): if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
if os.path.isfile(gguf_path): if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path) gguf_path = os.path.dirname(gguf_path)
self.safetensor_loader = None safetensor_loader = SafeTensorLoader(gguf_path)
if quantize == "w8a8_dynamic":
safetensor_loader = W8A8SafeTensorLoader(gguf_path)
else:
safetensor_loader = SafeTensorLoader(gguf_path)
if safetensor_loader.tensor_file_map:
self.safetensor_loader = safetensor_loader
return
self.tensor_info = {} self.tensor_info = {}
self.gguf_path = gguf_path self.gguf_path = gguf_path
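Selecting the W8A8 path is driven entirely by the new quantize argument; a brief usage sketch of the classes defined in this file (paths are placeholders):

loader = GGUFLoader("/path/to/w8a8-weights", quantize="w8a8_dynamic")  # uses W8A8SafeTensorLoader
plain = GGUFLoader("/path/to/model-weights")                           # falls back to SafeTensorLoader / GGUF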


@ -0,0 +1,77 @@
import time
import torch
import torch_npu
import sys
import os
from ktransformers.util.utils import USE_NPU_GRAPH
if USE_NPU_GRAPH:
CAPTURE_PLUGIN_PATH = os.environ.get("CAPTURE_PLUGIN_PATH")
if CAPTURE_PLUGIN_PATH is None:
raise RuntimeError("env CAPTURE_PLUGIN_PATH not exist")
sys.path.append(CAPTURE_PLUGIN_PATH)
from libgraph_capture import graph_capture_init
from libgraph_capture import graph_capture_destroy
from libgraph_capture import graph_capture_begin
from libgraph_capture import graph_capture_end
from libgraph_capture import graph_capture_replay
from libgraph_capture import graph_capture_launch_callback
class NpuGraph:
def init(self):
ret = graph_capture_init()
if ret != 0:
exit()
def destroy(self):
ret = graph_capture_destroy()
if ret != 0:
exit()
def capture_begin(
self,
stream,
capture_error_mode="global"):
torch.npu.synchronize()
torch.npu.empty_cache()
ret = graph_capture_begin(stream, capture_error_mode)
if ret != 0:
exit()
def capture_end(
self,
stream):
ret = graph_capture_end(stream)
if ret != 0:
exit()
def replay(
self,
stream):
ret = graph_capture_replay(stream)
if ret != 0:
exit()
def launch_callback(self, func, data, block, stream):
graph_capture_launch_callback(func, data, block, stream)
class graph:
def __init__(
self,
npu_graph: NpuGraph,
pool,
stream,
capture_error_mode: str = "global"):
self.npu_graph = npu_graph
self.stream = stream.npu_stream
def __enter__(self):
self.npu_graph.capture_begin(self.stream)
def __exit__(self, exc_type, exc_val, exc_tb):
self.npu_graph.capture_end(self.stream)
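The graph helper mirrors torch.cuda.graph's context-manager protocol on top of the capture plugin. A hedged usage sketch; the stream and the work inside the with block are placeholders, and the import path is assumed from how this module is used elsewhere in the commit.

import torch_npu
from ktransformers.util.npu_graph import NpuGraph, graph  # assumed module path

g = NpuGraph()
g.init()
stream = torch_npu.npu.Stream()
with graph(g, pool=None, stream=stream):
    pass  # launch the decode-step kernels to be captured here
g.replay(stream.npu_stream)  # later: re-execute the captured work without relaunch overhead
g.destroy()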


@ -0,0 +1,218 @@
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from typing import Dict
import acl
import torch
import torch_npu
from torch import nn
import ktransformers.util.npu_graph as npu_graph
from ktransformers.util.utils import CUR_DEVICE
class NPUGraphRunner:
def __init__(self, deviceId):
torch.npu.set_compile_mode(jit_compile=False)
self.deviceId = deviceId
self.enable = False
self.debug = False
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self.tid = None
self.past_key_value = None
def init(self, batch_size, seq_length):
self.tmp_g = npu_graph.NpuGraph()
self.graph = torch.npu.NPUGraph()
self.main_stream = torch_npu.npu.Stream(device=self.deviceId)
self.update_stream = torch_npu.npu.Stream(device=self.deviceId)
self.stream = self.main_stream.npu_stream
self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId)
self.context, ret = acl.rt.get_context(self.deviceId)
if ret != 0:
print("get_context failed! ret: " + str(ret))
exit(-1)
self.exit_flag = False
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
self.workspace = None
if self.tid is None:
def process_callback(args_list):
ins = args_list[0]
ret = acl.rt.set_context(ins.context)
if ret != 0:
print("set_context failed! ret: " + str(ret))
exit(-1)
while True:
acl.rt.process_report(1)
if ins.exit_flag:
break
self.tid, ret = acl.util.start_thread(process_callback, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret = acl.rt.subscribe_report(self.tid, self.stream)
if ret != 0:
print("subscribe_report failed!")
exit(-1)
def destroy(self):
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Begin -------------\n', end='')
self.exit_flag = True
ret = acl.rt.unsubscribe_report(self.tid, self.stream)
if ret != 0:
print("unsubscribe_report failed!")
exit(-1)
self.enable = False
ret = acl.util.stop_thread(self.tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
self.tid = None
self.workspace = None
self.handle = []
self.ifa_param = []
self.event = []
self.first_update = True
del self.graph
self.tmp_g.destroy()
destroy_runner(self.deviceId)
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Finish -------------\n', end='')
def capture(
self,
model,
cur_token,
position_ids,
cache_position,
past_key_values,
main_device,
**kwargs,
) -> None:
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Begin -------------\n', end='')
self.enable = True
self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
self.seq_length = inputs_embeds.size()[1]
self.main_device = main_device
with torch.no_grad():
with torch.npu.graph(self.graph, stream=self.main_stream):
self.logits = model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs)[0]
if past_key_values != None:
past_key_values.change_seq_length(-1)
self.input_buffers = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"cache_position": cache_position,
}
self.output_buffers = {"logits": self.logits}
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Finish -------------\n', end='')
return
def forward(
self,
inputs_embeds,
position_ids,
cache_position,
) -> torch.Tensor:
def ifa_update_sync(param):
with torch.npu.stream(self.update_stream):
for i in range(len(self.handle)):
if self.first_update is False:
q_nope, kvCache, q_pe, kRopeCache, num_heads, \
softmax_scale, layer_idx, attn_output, softmax_lse = self.ifa_param[i]
torch.npu.graph_task_update_begin(self.update_stream, self.handle[i])
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
kvCache,
kvCache,
workspace=self.workspace,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.past_key_value.page_table_list[layer_idx],
block_size=self.past_key_value.page_size,
actual_seq_lengths_kv=self.past_key_value.position,
out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(self.update_stream)
self.event[i].record(self.update_stream)
self.ifa_update_tid, ret = acl.util.start_thread(ifa_update_sync, [self])
if ret != 0:
print("start_thread failed!")
exit(-1)
ret1 = acl.rt.memcpy(self.input_buffers["inputs_embeds"].data_ptr(), inputs_embeds.numel() * 2,
inputs_embeds.data_ptr(), inputs_embeds.numel() * 2, 3)
ret2 = acl.rt.memcpy(self.input_buffers["position_ids"].data_ptr(), position_ids.numel() * 8,
position_ids.data_ptr(), position_ids.numel() * 8, 3)
ret3 = acl.rt.memcpy(self.input_buffers["cache_position"].data_ptr(), cache_position.numel() * 8,
cache_position.data_ptr(), cache_position.numel() * 8, 3)
torch_npu.npu.synchronize()
with torch_npu.npu.stream(self.main_stream):
self.graph.replay()
self.first_update = False
ret = acl.util.stop_thread(self.ifa_update_tid)
if ret != 0:
print("stop_thread failed!")
exit(-1)
else:
self.ifa_update_tid = None
return self.output_buffers["logits"]
def launch_callback(self, func, data, block, stream):
self.tmp_g.launch_callback(func, data, block, stream)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
runner_dict = dict()
def check_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
return True
else:
return False
def destroy_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is not None:
runner_dict[deviceId] = None
def get_or_create_runner(deviceId: int):
runner = runner_dict.get(deviceId)
if runner is None:
runner = NPUGraphRunner(deviceId)
runner_dict[deviceId] = runner
return runner
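Putting the runner together, the intended life cycle is init, capture once, then call the runner (replay) for every later decode step, and finally destroy. A hedged sketch; the model, token, and cache objects are placeholders taken from the surrounding decode loop.

from ktransformers.util.npu_graph_runner import get_or_create_runner
from ktransformers.util.utils import CUR_DEVICE  # placeholder for the current NPU device id

runner = get_or_create_runner(CUR_DEVICE)        # one runner per NPU device
runner.init(batch_size=1, seq_length=1)
runner.past_key_value = past_key_values
runner.capture(model, cur_token, position_ids, cache_position, past_key_values,
               main_device=CUR_DEVICE, return_dict=False, use_cache=True)
for _ in range(max_new_tokens - 1):
    logits = runner(inputs_embeds, position_ids, cache_position)  # copies inputs in, replays the graph
runner.destroy()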


@ -31,8 +31,35 @@ if not torch.xpu.is_available():
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket import socket
import os
import re
import torch.distributed as dist
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
warm_uped = False warm_uped = False
W8A8_ENABLE = False
Q4_GGUF_LODER = None
USE_NPU_GRAPH = None
WARM_UP_SKIP_CNT = [1, 1]
_USE_NPU_GRAPH = False
_MAX_DECODE_PROFILE = 3
CUR_DEVICE = None
_MAX_CHUNK_SIZE = max(int(os.getenv("_MAX_CHUNK_SIZE", 4096)), 512)
def get_use_npu_graph():
assert _USE_NPU_GRAPH is not None, "USE_NPU_GRAPH is not set"
return _USE_NPU_GRAPH
def get_free_ports(n: int, continue_prot: list): def get_free_ports(n: int, continue_prot: list):
sockets = [] sockets = []
ports = [] ports = []
@ -50,6 +77,10 @@ def get_free_ports(n: int, continue_prot: list):
return ports return ports
def get_compute_capability(device:torch.device = None): def get_compute_capability(device:torch.device = None):
if use_torch_npu:
return 0
if torch.cuda.is_available(): if torch.cuda.is_available():
if device is None: if device is None:
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
@ -97,9 +128,16 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
if "cpu" in all_device_list: if "cpu" in all_device_list:
all_device_list.remove("cpu") all_device_list.remove("cpu")
if use_torch_npu:
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
all_device_list = list(all_device_list) all_device_list = list(all_device_list)
return all_device_list return all_device_list
# TODO: support NPU
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"): def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
prefix = prefix.replace("orig_module.", "") prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
@ -109,6 +147,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
key = prefix + name key = prefix + name
translated_key = key translated_key = key
# TODO: Merge all loader. # TODO: Merge all loader.
# I know this is ugly but lets do it for now. # I know this is ugly but lets do it for now.
if isinstance(gguf_loader, SafeTensorLoader): if isinstance(gguf_loader, SafeTensorLoader):
@ -120,7 +159,13 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key: if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype() target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
if use_torch_npu:
device = "cpu" if "embd" in translated_key else CUR_DEVICE
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
else:
print(f"loading {translated_key} to {device}")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif torch.xpu.is_available(): elif torch.xpu.is_available():
@ -149,6 +194,8 @@ def sync_all_device(all_device_list):
torch.cuda.synchronize(device) torch.cuda.synchronize(device)
elif "xpu" in device.lower(): elif "xpu" in device.lower():
torch.xpu.synchronize(device) torch.xpu.synchronize(device)
elif use_torch_npu:
torch_npu.npu.synchronize(device)
else: else:
raise RuntimeError("The device {} is not available".format(device)) raise RuntimeError("The device {} is not available".format(device))
@ -228,20 +275,68 @@ def tf_logits_warper(generation_config):
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None): num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
import os import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map device_map = model.gguf_loader.tensor_device_map
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device if use_torch_npu:
vocabulary_size = model.config.vocab_size
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
torch_device = CUR_DEVICE
else:
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device) inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map) all_cuda_device = get_all_used_cuda_device(device_map)
tokens = [] tokens = []
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if cuda_graph_runner is None:
use_cuda_graph = False
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
if use_cuda_graph:
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
# torch.cuda.set_device(torch_device)
torch_npu.npu.set_device(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
if past_key_values != None:
past_key_values.change_seq_length(1)
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
for device in all_cuda_device:
# torch.cuda.synchronize(device)
torch_npu.npu.synchronize(device)
if generation_config.do_sample:
logits = logits / temperature
torch.manual_seed(0)
probs = logits.view(batch_size, vocabulary_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = next_token_fake
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
next_token = next_token.squeeze(-1)
else:
next_token_scores = logits_warper(inputs, logits[:, -1, :])
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
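torch_npu._npu_topk_topp_sampling is a fused device kernel; for reference, equivalent top-k/top-p sampling can be written with standard PyTorch ops. This is a hedged functional equivalent to clarify the decode path, not the kernel's implementation:

import torch

def topk_topp_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: (batch, vocab), already softmax-normalized; returns (batch, 1) sampled token ids.
    topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)           # keep the k most likely tokens
    cum = torch.cumsum(topk_probs, dim=-1)
    topk_probs = topk_probs.masked_fill(cum - topk_probs > top_p, 0)  # drop the tail beyond top_p
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)    # renormalize
    choice = torch.multinomial(topk_probs, num_samples=1)
    return topk_idx.gather(-1, choice)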
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True): def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if use_torch_npu:
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
if cuda_graph_runner is None: if cuda_graph_runner is None:
use_cuda_graph = False use_cuda_graph = False
if use_cuda_graph: if use_cuda_graph:
@ -252,6 +347,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch.cuda.set_device(torch_device) torch.cuda.set_device(torch_device)
elif torch.xpu.is_available(): elif torch.xpu.is_available():
torch.xpu.set_device(torch_device) torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.npu.set_device(torch_device)
else: else:
raise RuntimeError(f"The device: {torch_device} is not available") raise RuntimeError(f"The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device) inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
@ -279,6 +376,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")) inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else: else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device) inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla: if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages) MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all() MLAWrapperSingleton.need_plan_all()
@ -288,11 +386,88 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device) )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
return logits return logits
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
global warm_uped
global _USE_NPU_GRAPH
if use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
npu_graph_runner.init(batch_size, seq_length)
with torch_npu.npu.stream(npu_graph_runner.main_stream):
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
warm_uped = True
_USE_NPU_GRAPH = True
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
cuda_graph_runner = npu_graph_runner
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
cache_position, past_key_values, logits_warper, generation_config,
use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
npu_graph_runner.destroy()
_USE_NPU_GRAPH = False
else:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(torch_device) torch.cuda.set_device(torch_device)
elif torch.xpu.is_available(): elif torch.xpu.is_available():
torch.xpu.set_device(torch_device) torch.xpu.set_device(torch_device)
elif use_torch_npu:
torch_npu.npu.set_device(torch_device)
else: else:
raise RuntimeError(f"The device: {torch_device} is not available") raise RuntimeError(f"The device: {torch_device} is not available")
with torch.no_grad(): with torch.no_grad():
@ -304,6 +479,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None) past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
else: else:
past_key_values = DynamicNormalCache.from_legacy_cache(None) past_key_values = DynamicNormalCache.from_legacy_cache(None)
elif use_torch_npu and static_cache:
assert isinstance(static_cache, StaticCache), '[ERROR] static_cache must be an instance of StaticCache'
past_key_values = static_cache
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
print('[WARN] existing StaticCache is too small, creating a new StaticCache...')
past_key_values = StaticCache(
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
)
else:
past_key_values.reset()
elif mode != 'long_context': elif mode != 'long_context':
past_key_values = StaticCache( past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
@ -320,19 +505,67 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits_warper = tf_logits_warper(generation_config) logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32) cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
if use_torch_npu:
past_key_values.position[0] = seq_length + 1
generated_ids = torch.zeros( generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
) )
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int) generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
start_time = time.time() start_time = time.time()
chunk_start = 0 logits = None
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length) def prefill_wrapper(prof=None):
if past_key_values != None: nonlocal logits
past_key_values.cur_idx=cache_position[chunk_start:chunk_end] chunk_start = 0
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values) while chunk_start < seq_length:
chunk_start += chunk_size chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
if use_torch_npu:
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
WARM_UP_SKIP_CNT[0] -= 1
else:
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
next_token_scores = logits_warper(inputs, logits[:, -1, :]) next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample: if generation_config.do_sample:
@ -348,56 +581,106 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
prefill_count = seq_length prefill_count = seq_length
prefill_time = first_token_time prefill_time = first_token_time
if force_think: if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print("<think>") if force_think:
print(stream.put(next_token.item()), end="", flush=True) print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
elif not use_torch_npu:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token generated_ids[:, seq_length] = next_token
tokens.append(int(next_token)) tokens.append(int(next_token))
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32) cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
seq_length += 1 seq_length += 1
if use_torch_npu:
past_key_values.position += 1
cuda_graph_runner = None cuda_graph_runner = None
start_time = time.time() start_time = time.time()
for i in range(1, max_new_tokens):
if use_flashinfer_mla: if not use_torch_npu:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None, for i in range(1, max_new_tokens):
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, if use_flashinfer_mla:
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16) MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
global warm_uped num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ): model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
warm_uped = True global warm_uped
cuda_graph_runner = CUDAGraphRunner() if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True) warm_uped = True
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device) cuda_graph_runner = CUDAGraphRunner()
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
generated_ids[:, cache_position] = next_token.int() next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
tokens.append(int(next_token)) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
seq_length += 1 generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': seq_length += 1
print(stream.end(), end="", flush=True)
break if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
else:
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
else: else:
print(stream.put(next_token.item()), end="", flush=True) decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
cache_position += 1 WARM_UP_SKIP_CNT[1] -= 1
position_ids = cache_position.unsqueeze(0)
total_time = time.time() - start_time total_time = time.time() - start_time
tokens_generated = len(tokens) tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time tokens_per_second = tokens_generated / total_time
print("") if not use_torch_npu:
print("")
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
else:
tp_size = get_tensor_parallel_size()
if torch.distributed.get_rank() % tp_size == 0:
rank = f"[rank:{torch.distributed.get_rank()}]"
msg = f"\n{rank} Eval Time\n"
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
msg += rank + f"eval count: {tokens_generated} token(s)\n"
msg += rank + f"eval duration: {total_time:.9f}s\n"
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
print(msg)
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
return tokens return tokens
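All of the profiler hooks above are environment-driven; nothing is traced unless the flags are set before launch, and the WARM_UP_SKIP_CNT guard skips the first prefill/decode so warm-up work does not pollute the trace. A short example of enabling them from Python before the entry point runs (shell exports work the same way):

import os

os.environ["PROF_PREFILL"] = "1"  # write a prefill trace to ./prefill_prof
os.environ["PROF_DECODE"] = "1"   # write a decode trace to ./decode_prof
os.environ["PROF_RANK"] = "0,1"   # only these ranks record the decode trace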


@ -12,6 +12,8 @@ from safetensors.torch import save_file
import re import re
from collections import defaultdict from collections import defaultdict
SKIP_MTP = True
def read_safetensor_keys_from_folder(folder_path)->dict: def read_safetensor_keys_from_folder(folder_path)->dict:
""" """
:param folder_path: folder path :param folder_path: folder path
@ -36,7 +38,7 @@ def read_safetensor_keys_from_folder(folder_path)->dict:
try: try:
with safe_open(file_path, framework="pt") as f: with safe_open(file_path, framework="pt") as f:
for key in f.keys(): for key in f.keys():
if "model.layers.61" in key: if SKIP_MTP and "model.layers.61" in key:
# skip MTP layer # skip MTP layer
continue continue
# try: # try:
@ -94,6 +96,28 @@ def combine_tensor_sources(safetensor_path:str, gguf_path:str):
return target_tensor_map, gguf_loader return target_tensor_map, gguf_loader
def combine_w8a8_tensor_sources(safetensor_path: str, gguf_path: str):
gguf_loader = GGUFLoader(gguf_path)
gguf_tensor_file_map = gguf_loader.tensor_file_map
safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
# build a map for the key to the tensor
# according to the key, we can get the tensor from the file
target_tensor_map = {}
for key in safetensor_tensor_file_map.keys():
# for all experts, we use the gguf tensor
if ".mlp.experts." in key and "weight_scale" not in key and "weight_offset" not in key:
key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
translated_key = translate_name(key)
target_tensor_map[key] = gguf_tensor_file_map[translated_key]
elif ".mlp.experts." in key and ("weight_scale" not in key or "weight_offset" not in key):
continue
else:
target_tensor_map[key] = safetensor_tensor_file_map[key]
return target_tensor_map, gguf_loader
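The key collapsing in the expert branch drops the per-expert index so that every expert of a layer resolves to the layer's single fused GGUF tensor. A concrete (illustrative) example:

key = "model.layers.3.mlp.experts.17.down_proj.weight"
parts = key.split(".")
print(".".join(parts[:5] + parts[-2:]))  # model.layers.3.mlp.experts.down_proj.weight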
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader): def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
# Ensure output directory exists # Ensure output directory exists
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
@ -193,6 +217,7 @@ def main():
parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3") parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf") parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8") parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
parser.add_argument("--safetensors_format", type=str, help="Safetensors format", default="fp8")
# print all the arguments # print all the arguments
print("All the arguments:") print("All the arguments:")
@ -204,8 +229,18 @@ def main():
safetensor_path = args.safetensor_path safetensor_path = args.safetensor_path
gguf_path = args.gguf_path gguf_path = args.gguf_path
output_path = args.output_path output_path = args.output_path
safetensors_format = args.safetensors_format
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) match safetensors_format:
case "w8a8":
global SKIP_MTP
SKIP_MTP = False
target_tensor_map, gguf_loader = combine_w8a8_tensor_sources(safetensor_path, gguf_path)
case "fp8":
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
case _:
raise ValueError(f"Unsupported safetensors format: {safetensor_path}")
write_combined_tensor(target_tensor_map, output_path, gguf_loader) write_combined_tensor(target_tensor_map, output_path, gguf_loader)
return return


@ -673,10 +673,29 @@ if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
ext_modules.append( ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
) )
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
elif torch.xpu.is_available(): elif torch.xpu.is_available():
ext_modules = [ ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
] ]
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
elif KTRANSFORMERS_BUILD_NPU: elif KTRANSFORMERS_BUILD_NPU:
ext_modules = [ ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
@ -687,11 +706,10 @@ elif KTRANSFORMERS_BUILD_NPU:
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
) )
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
install_requires=triton_dep,
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=ext_modules
)

File diff suppressed because it is too large

third_party/llamafile/iqk_mul_mat_arm.inc (vendored, 5866 lines)

File diff suppressed because it is too large


@ -0,0 +1,10 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm80.cpp
// Copyright 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __aarch64__
#define iqk_mul_mat iqk_mul_mat_arm80
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm80
#include "iqk_mul_mat.inc"
#endif // __aarch64__

third_party/llamafile/iqk_mul_mat_x86.inc (vendored, 4925 lines)

File diff suppressed because it is too large


@ -1,204 +1,7 @@
// Adapted from #if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp // use the ARM version
// Copyrigth 2024 Mozilla Foundation. #include "sgemm_arm.cpp"
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else #else
// Intel Xeon Skylake+ (2015-) // use the x86 version
sgemm = llamafile_sgemm_amd_avx512f; #include "sgemm_x86.cpp"
mixmul = llamafile_mixmul_amd_avx512f; #endif
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
long hwcap = getauxval(AT_HWCAP);
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// e.g. Apple M1, Raspberry Pi 5
sgemm = llamafile_sgemm_arm82;
mixmul = llamafile_mixmul_arm82;
iqk_mixmul = iqk_mul_mat_moe_arm82;
} else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
}
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}

204 third_party/llamafile/sgemm_arm.cpp vendored Normal file
@@ -0,0 +1,204 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
// long hwcap = getauxval(AT_HWCAP);
// if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
// (hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
// (hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// // e.g. Apple M1, Raspberry Pi 5
// sgemm = llamafile_sgemm_arm82;
// mixmul = llamafile_mixmul_arm82;
// iqk_mixmul = iqk_mul_mat_moe_arm82;
// } else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
// }
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
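// ---------------------------------------------------------------------------
// Editor's sketch (not part of this commit): a minimal single-threaded F32
// call through the wrapper above. Parameter order follows the signature of
// llamafile_sgemm; GGML_TASK_TYPE_COMPUTE is assumed to be the task value for
// the actual multiply pass, and GGML_PREC_DEFAULT requests the default
// internal compute type. A false return means no handwritten kernel matched
// (e.g. k not a multiple of the vector width) and the caller should fall back
// to ggml's generic matmul path.
static bool example_sgemm_f32(long m, long n, long k,
                              const float* A,  // m x k, row stride k
                              const float* B,  // n x k, row stride k
                              float* C) {      // written with column stride m
    return llamafile_sgemm(m, n, k,
                           A, /*lda=*/k,
                           B, /*ldb=*/k,
                           C, /*ldc=*/m,
                           /*ith=*/0, /*nth=*/1,
                           GGML_TASK_TYPE_COMPUTE,
                           GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
                           GGML_PREC_DEFAULT);
}
// ---------------------------------------------------------------------------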
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}

204 third_party/llamafile/sgemm_x86.cpp vendored Normal file
@@ -0,0 +1,204 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
// #include <cpuid.h>
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
// #include <sys/auxv.h>
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
// typeof(llamafile_sgemm)* sgemm;
// typeof(llamafile_mixmul)* mixmul;
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#if defined(__x86_64__) || defined(_M_X64)
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
long hwcap = getauxval(AT_HWCAP);
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// e.g. Apple M1, Raspberry Pi 5
sgemm = llamafile_sgemm_arm82;
mixmul = llamafile_mixmul_arm82;
iqk_mixmul = iqk_mul_mat_moe_arm82;
} else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
}
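// Editor's note (not part of this commit): getauxval(AT_HWCAP) is the Linux
// runtime CPU-feature query used above; the symbol is declared in
// <sys/auxv.h>, and the aarch64 HWCAP_FPHP / HWCAP_ASIMDHP / HWCAP_ASIMDDP
// bits come from <asm/hwcap.h>. Those headers are not included explicitly in
// this file (the <sys/auxv.h> include above is commented out), so they are
// presumably picked up transitively; a standalone version of this probe would
// look like:
//
//   #include <sys/auxv.h>   // getauxval, AT_HWCAP
//   #include <asm/hwcap.h>  // HWCAP_FPHP, HWCAP_ASIMDHP, HWCAP_ASIMDDP
//   unsigned long hwcap = getauxval(AT_HWCAP);
//   bool fast_fp16 = (hwcap & HWCAP_FPHP) && (hwcap & HWCAP_ASIMDHP) &&
//                    (hwcap & HWCAP_ASIMDDP);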
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}

@@ -5,6 +5,7 @@
#ifdef __aarch64__ #ifdef __aarch64__
#define llamafile_mixmul llamafile_mixmul_arm80 #define llamafile_mixmul llamafile_mixmul_arm80
#define iqk_mul_mat iqk_mul_mat_arm80
#include "tinyblas_cpu_mixmul.inc" #include "tinyblas_cpu_mixmul.inc"
/** /**

@@ -1,361 +1,7 @@
// Adapted from #if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc // Use the x86 version
// Copyright 2024 Mozilla Foundation. #include "tinyblas_cpu_sgemm_arm.inc"
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
namespace {
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else #else
return NOT_SUPPORTED; // Use the ARM version
#endif #include "tinyblas_cpu_sgemm_x86.inc"
} #endif
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
  moonll: accept more Btype values here
*/
if (Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}

@@ -0,0 +1,471 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
#include <arm_neon.h>
#include <ostream>
#include <iostream>
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
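// ---------------------------------------------------------------------------
// Editor's sketch (not part of this commit): the reference semantics the
// kernels below implement for the plain F32 case. A is m x k with row stride
// lda, B is n x k with row stride ldb, and C is written so that
// C[j*ldc + i] = dot(row i of A, row j of B), i.e. C = A^T * B when C is
// viewed column-major with leading dimension ldc.
static void sgemm_reference_f32(long m, long n, long k,
                                const float* A, long lda,
                                const float* B, long ldb,
                                float* C, long ldc) {
    for (long j = 0; j < n; ++j)
        for (long i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (long l = 0; l < k; ++l)
                sum += A[i * lda + l] * B[j * ldb + l];
            C[j * ldc + i] = sum;
        }
}
// ---------------------------------------------------------------------------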
namespace {
template <typename TC>
void SgemmHelperN1Neon2(long m, long n, long k, const float16_t* A, long lda, const float16_t* B, long ldb,
TC* C, long ldc, int ith, int nth) {
// A m * k B n * k c n * m
const long NVL = 8;
long kk = k / (NVL * 4);
kk = kk * (NVL * 4);
long length = (m / nth) + (ith < (m % nth) ? 1 : 0);
long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth));
long endRow = startRow + length;
for (long i = startRow; i < endRow; i ++) {
const float16_t* tA = A + i * lda;
float32x4_t c0 = vdupq_n_f32(0);
float32x4_t c1 = vdupq_n_f32(0);
float32x4_t c2 = vdupq_n_f32(0);
float32x4_t c3 = vdupq_n_f32(0);
float32x4_t c4 = vdupq_n_f32(0);
float32x4_t c5 = vdupq_n_f32(0);
float32x4_t c6 = vdupq_n_f32(0);
float32x4_t c7 = vdupq_n_f32(0);
for (long j = 0; j < kk; j += NVL * 4) {
__builtin_prefetch(tA + 192, 0, 0);
float16x8_t a0 = vld1q_f16(tA + j);
float16x8_t b0 = vld1q_f16(B + j);
c0 = vfmlalq_low_f16(c0, a0, b0);
c1 = vfmlalq_high_f16(c1, a0, b0);
float16x8_t a1 = vld1q_f16(tA + j + NVL);
float16x8_t b1 = vld1q_f16(B + j + NVL);
c2 = vfmlalq_low_f16(c2, a1, b1);
c3 = vfmlalq_high_f16(c3, a1, b1);
float16x8_t a2 = vld1q_f16(tA + j + NVL * 2);
float16x8_t b2 = vld1q_f16(B + j + NVL * 2);
c4 = vfmlalq_low_f16(c4, a2, b2);
c5 = vfmlalq_high_f16(c5, a2, b2);
float16x8_t a3 = vld1q_f16(tA + j + NVL * 3);
float16x8_t b3 = vld1q_f16(B + j + NVL * 3);
c6 = vfmlalq_low_f16(c6, a3, b3);
c7 = vfmlalq_high_f16(c7, a3, b3);
}
if (k - kk >= NVL * 2) {
float16x8_t a0 = vld1q_f16(tA + kk);
float16x8_t b0 = vld1q_f16(B + kk);
c0 = vfmlalq_low_f16(c0, a0, b0);
c1 = vfmlalq_high_f16(c1, a0, b0);
float16x8_t a1 = vld1q_f16(tA + kk + NVL);
float16x8_t b1 = vld1q_f16(B + kk + NVL);
c2 = vfmlalq_low_f16(c2, a1, b1);
c3 = vfmlalq_high_f16(c3, a1, b1);
kk += NVL * 2;
}
if (k - kk >= NVL) {
float16x8_t a = vld1q_f16(tA + kk);
float16x8_t b = vld1q_f16(B + kk);
c0 = vfmlalq_low_f16(c0, a, b);
c1 = vfmlalq_high_f16(c1, a, b);
kk += NVL;
}
TC sum = 0.0f;
for (long j = kk; j < k; j ++) {
sum += (float32_t)tA[j] * (float32_t)B[j];
}
c0 = vaddq_f32(c0, c1);
c2 = vaddq_f32(c2, c3);
c4 = vaddq_f32(c4, c5);
c6 = vaddq_f32(c6, c7);
c0 = vaddq_f32(c0, c2);
c4 = vaddq_f32(c4, c6);
sum += vaddvq_f32(c0) + vaddvq_f32(c4);
C[i] = sum;
}
return;
}
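// Editor's note (not part of this commit): the helper above relies on the
// ARMv8.2 FP16FML widening multiply-accumulate intrinsics. vfmlalq_low_f16
// multiplies the low four fp16 lanes of its two vector operands and adds the
// products into a float32x4_t accumulator; vfmlalq_high_f16 does the same for
// the upper four lanes, so each float16x8_t pair contributes eight products
// without an explicit fp16->fp32 conversion step. A minimal eight-element dot
// product using the same pattern:
static inline float dot8_f16_example(const float16_t* x, const float16_t* y) {
    float32x4_t acc_lo = vdupq_n_f32(0);
    float32x4_t acc_hi = vdupq_n_f32(0);
    float16x8_t a = vld1q_f16(x);              // eight fp16 values from x
    float16x8_t b = vld1q_f16(y);              // eight fp16 values from y
    acc_lo = vfmlalq_low_f16(acc_lo, a, b);    // lanes 0..3, widened to f32
    acc_hi = vfmlalq_high_f16(acc_hi, a, b);   // lanes 4..7, widened to f32
    return vaddvq_f32(vaddq_f32(acc_lo, acc_hi));  // horizontal sum
}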
template <typename TC>
void SgemmHelperN1(long m, long n, long k, const ggml_fp16_t* A_, long lda, const ggml_fp16_t* B_, long ldb,
TC* C, long ldc, int ith, int nth) {
// A m * k B n * k c n * m
float16_t *A = (float16_t*)A_;
float16_t *B = (float16_t*)B_;
long rowsPerThread = m / nth;
long startRow = ith * rowsPerThread;
long endRow = (ith == nth - 1) ? m : startRow + rowsPerThread;
for (long i = startRow; i < endRow; i ++) {
TC sum = 0.0f;
for (long j = 0; j < k; j ++) {
sum += (float32_t)A[i * lda + j] * (float32_t)B[j];
}
C[i] = sum;
}
return;
}
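// Editor's note (not part of this commit): the two helpers above split rows
// across threads differently. SgemmHelperN1Neon2 gives each of the first
// (m % nth) threads one extra row, while this scalar fallback assigns every
// remainder row to the last thread. The balanced split used by the NEON
// helper, for thread `ith` of `nth`, is:
//
//   long length   = (m / nth) + (ith < (m % nth) ? 1 : 0);
//   long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth));
//   long endRow   = startRow + length;   // this thread handles [startRow, endRow)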
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise) {
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
return true;
}
return NOT_SUPPORTED;
}
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise) {
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
return true;
}
return NOT_SUPPORTED;
}
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
  moonll: accept more Btype values here
*/
// if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}

@@ -0,0 +1,361 @@
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
namespace {
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
/*
  moonll: accept more Btype values here
*/
if (Ctype == GGML_TYPE_F32){
if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}