mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 20:49:55 +00:00)

commit 7d51a13c9b ("support npu"), parent dd0e41b3b8
34 changed files with 14004 additions and 5626 deletions
@@ -44,6 +44,10 @@ option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM"
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)

if(KTRANSFORMERS_USE_NPU)
    add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue

@@ -90,6 +94,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
    endif ()
    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
else()
    if(KTRANSFORMERS_USE_NPU)
        list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
    endif()
    check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
    if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
        list(APPEND ARCH_FLAGS -mfp16-format=ieee)

@@ -117,37 +124,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
        CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
    message(STATUS "x86 detected")
    set(HOST_IS_X86 TRUE)
    set(HAS_AVX512 TRUE)
    set(__HAS_AMX__ TRUE)
    add_compile_definitions(__x86_64__)
    # check AVX512
    execute_process(
        COMMAND lscpu
        OUTPUT_VARIABLE LSCPU_OUTPUT
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
    if(NOT KTRANSFORMERS_USE_NPU)
        set(HOST_IS_X86 TRUE)
        set(HAS_AVX512 TRUE)
        set(__HAS_AMX__ TRUE)
        add_compile_definitions(__x86_64__)
        # check AVX512
        execute_process(
            COMMAND lscpu
            OUTPUT_VARIABLE LSCPU_OUTPUT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")

    string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)

    if (COMPILER_SUPPORTS_AVX512F GREATER -1)
        message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
        add_compile_definitions(__HAS_AVX512F__)
    else()
        message(STATUS "Compiler and/or CPU do NOT support AVX512F")
        set(HAS_AVX512 False)
    endif()
        string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)

        if (COMPILER_SUPPORTS_AVX512F GREATER -1)
            message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
            add_compile_definitions(__HAS_AVX512F__)
        else()
            message(STATUS "Compiler and/or CPU do NOT support AVX512F")
            set(HAS_AVX512 False)
        endif()

    # check AMX
    string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)

    if(COMPILER_SUPPORTS_AMX GREATER -1)
        message(STATUS "Compiler supports AMX")
        add_compile_definitions(__HAS_AMX__)
    else()
        message(STATUS "Compiler does NOT support AMX")
    endif()
        # check AMX
        string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)

        if(COMPILER_SUPPORTS_AMX GREATER -1)
            message(STATUS "Compiler supports AMX")
            add_compile_definitions(__HAS_AMX__)
        else()
            message(STATUS "Compiler does NOT support AMX")
        endif()
    if (MSVC)
        # instruction set detection for MSVC only
        if (LLAMA_NATIVE)

@@ -281,6 +289,8 @@ if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)

    if (KTRANSFORMERS_USE_ROCM)
        find_package(HIP REQUIRED)
        if(HIP_FOUND)
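For reference, a minimal sketch of driving the configure step with the new switch from Python; the source and build directory names and the bare cmake invocation are assumptions here (the project normally wraps its build in its own install scripts):

import subprocess

def configure_npu_build(source_dir: str = ".", build_dir: str = "build") -> None:
    # Turning the option ON makes CMake add -DKTRANSFORMERS_USE_NPU=1 as a compile definition.
    subprocess.run(
        ["cmake", "-S", source_dir, "-B", build_dir, "-DKTRANSFORMERS_USE_NPU=ON"],
        check=True,
    )

configure_npu_build()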
ktransformers/local_chat_npu.py (new file, 257 lines)
@@ -0,0 +1,257 @@
"""
Description  :
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import os
import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import torch.distributed as dist

import logging
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils
from ktransformers.models.custom_cache import StaticCache
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}
torch.npu.config.allow_internal_format = True

ktransformer_rules_dir = (
    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "npu/DeepSeek-V3-Chat.yaml",
}
torch.npu.set_compile_mode(jit_compile=False)


import sys, signal, faulthandler
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)


def local_chat(
    model_path: str | None = None,
    optimize_config_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = False,
    prompt_file : str | None = None,
    mode: str = "normal",
    force_think: bool = False,
    chunk_size: int = utils._MAX_CHUNK_SIZE,
    q4_gguf_path: str | None = None,
    tp: int = 1,
):
    utils.USE_NPU_GRAPH = use_cuda_graph
    torch.npu.config.allow_internal_format = False
    torch.set_grad_enabled(False)
    Config().cpu_infer = cpu_infer

    local_rank, world_size = setup_model_parallel(tp=tp)
    if utils.CUR_DEVICE is None:
        utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if use_cuda_graph:
        from ktransformers.util import npu_graph_runner
        npu_graph_runner.LAYER_ID = config.num_hidden_layers
    if mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if (
                "Qwen2Moe" in config.architectures[0]
            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"

            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_config_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0]) if local_rank == 0 else None
            optimize_config_path = default_optimize_rules[config.architectures[0]]
            print(f'{optimize_config_path=}') if local_rank == 0 else None
        else:
            optimize_config_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, q4_gguf_path=q4_gguf_path)
    get_absort_weight(model, config)

    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        gen_config = GenerationConfig(
            temperature=0.6,
            top_p=0.95,
            do_sample=True
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    system = platform.system()
    if system == "Windows":
        os.system("cls") if local_rank == 0 else None
    else:
        os.system("clear") if local_rank == 0 else None

    print(f"{model=}") if local_rank == 0 else None

    batch_size, seq_length = 1, 1024
    device_map = model.gguf_loader.tensor_device_map
    static_cache = StaticCache(
        config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map,
        dtype = model.dtype
    )
    chunk_size = int(chunk_size)
    new_chunk_size = min(max(chunk_size, 512), utils._MAX_CHUNK_SIZE)
    if new_chunk_size != chunk_size:
        chunk_size = new_chunk_size
        print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {chunk_size}.')

    torch.distributed.barrier()
    while True:
        if local_rank == 0:
            try:
                content = input("Chat: ").strip()
            except KeyboardInterrupt:
                dist.barrier()
                print('Exit all ranks with KeyboardInterrupt!')
                sys.exit(0)
            if content.startswith('"""'):  # prefix """
                # multi lines input
                content = content[3:] + "\n"
                while True:
                    line = input("")
                    if line.endswith('"""'):
                        # end multi lines input
                        line = line[:-3]  # suffix """
                        if line:
                            content += line + "\n"
                        break
                    else:
                        content += line + "\n"

            if content == "":
                if prompt_file != None:
                    content = open(prompt_file, "r").read()
                else:
                    continue
            elif os.path.isfile(content):
                f = open(content, "r")
                content = f.readlines()
                f.close()
            else:
                content = [f"{len(content)},{max_new_tokens},{content}"]
        else:
            content = [""]

        for line in content:
            content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
            if world_size > 1:
                content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE)
                all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
                dist.barrier()
                dist.all_gather(all_content_sizes, content_size)
                max_content_size = max([size.item() for size in all_content_sizes])

                padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE)
                padded_content_tensor[:len(content_tensor)] = content_tensor

                all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)]
                dist.barrier()
                dist.all_gather(all_content_tensors, padded_content_tensor)
                content_tensor = all_content_tensors[0][:all_content_sizes[0].item()]
                line = bytes(content_tensor.cpu().numpy()).decode()

            parts = line.split(",")
            input_tokens = int(parts[0])
            max_new_tokens = int(parts[1])
            line = line[line.index(",", line.index(",") + 1) + 1:]

            messages = [{"role": "user", "content": line}]
            input_tensor = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            )
            if force_think:
                token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
                input_tensor = torch.cat(
                    [input_tensor, token_thinks], dim=1
                )
            if mode == 'long_context':
                assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
                "please change max_seq_len in ~/.ktransformers/config.yaml"

            if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
                generated = prefill_and_generate(
                    model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
                    use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim,
                    static_cache=static_cache
                )
            else:
                generated = prefill_and_generate(
                    model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,
                    static_cache=static_cache
                )


if __name__ == "__main__":
    fire.Fire(local_chat)
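A hedged usage sketch for the new script: the keyword names mirror the local_chat parameters above, the model and GGUF paths are placeholders, and a multi-rank run (tp > 1) is assumed to be started by a launcher that sets LOCAL_RANK and WORLD_SIZE, since setup_model_parallel reads those variables.

# Single-rank example; both paths are placeholders.
from ktransformers.local_chat_npu import local_chat

local_chat(
    model_path="/path/to/DeepSeek-V3",      # directory with the HF config and tokenizer
    gguf_path="/path/to/deepseek-v3-gguf",  # directory containing the GGUF weights
    max_new_tokens=512,
    use_cuda_graph=True,                    # reused flag name: enables the NPU graph runner here
    tp=1,
)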
@@ -16,6 +16,16 @@ try:
    from ktransformers.server.balance_serve.settings import sched_ext
except:
    print("no balance_serve")


try:
    import torch_npu
    from ktransformers.util import utils

    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

class StaticCache(transformers.StaticCache):
    """
    Static Cache class to be used with `torch.compile(model)`.

@@ -37,6 +47,10 @@ class StaticCache(transformers.StaticCache):
    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
        Cache.__init__(self)
        self.max_batch_size = max_batch_size

        if use_torch_npu:
            self.position = [0]

        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        if config.architectures[0] == "DeepseekV3ForCausalLM":

@@ -56,8 +70,18 @@ class StaticCache(transformers.StaticCache):
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
            self.page_size = 64
            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size

            if use_torch_npu:
                self.page_size = 128
                self.page_size_tensor = torch.tensor(
                    self.page_size,
                    dtype=torch.int32,
                ).npu()
                self.max_pages_per_batch = (self.max_cache_len + self.page_size - 1) // self.page_size
                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size * self.max_batch_size
            else:
                self.page_size = 64
                self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            self.kv_lora_rank = config.kv_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim

@@ -71,9 +95,14 @@ class StaticCache(transformers.StaticCache):
                target_device = device

                if target_device not in self.page_table_map:
                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                    for seq_id in range(max_batch_size):
                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                    if use_torch_npu:
                        page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
                        for seq_id in range(max_batch_size):
                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
                    else:
                        page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                        for seq_id in range(max_batch_size):
                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                    self.page_table_map[target_device] = page_table

                self.page_table_list.append(self.page_table_map[target_device])

@@ -140,11 +169,24 @@ class StaticCache(transformers.StaticCache):
        self.past_tokens[layer_idx] += cache_position.size(0)
        #print(cache_position)
        if self.is_MLA:
            page_idx = cache_position // self.page_size
            page_offset = cache_position % self.page_size
            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
            if use_torch_npu:
                page_idx = cache_position // self.page_size_tensor
                page_offset = cache_position % self.page_size_tensor

                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)

                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
                page_idx = page_idx + page_idx_offset.unsqueeze(1)

                combined = torch.cat([key_states, value_states], dim=-1)
                combined = combined.contiguous()
            else:
                page_idx = cache_position // self.page_size
                page_offset = cache_position % self.page_size
                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
            return k_out, self.page_table_list[layer_idx]
        else:
            k_out[:, :, cache_position] = key_states

@@ -178,6 +220,9 @@ class StaticCache(transformers.StaticCache):
            if self.value_cache[layer_idx] is not None:
                self.value_cache[layer_idx].zero_()
            self.past_tokens[layer_idx] = 0

        if use_torch_npu:
            self.position = [0]

    def remove_suffix(self, start_pos):
        for layer_idx in range(len(self.key_cache)):
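To make the NPU paging scheme above easier to follow, here is a small standalone sketch of the index arithmetic used in the use_torch_npu branch of update (page size 128, one contiguous page range per batch element). It only mirrors the arithmetic; the actual cache write and the page_size_tensor device handling stay as in the class above.

import torch

def npu_page_indices(cache_position: torch.Tensor, max_batch_size: int,
                     page_size: int = 128, max_pages_per_batch: int = 64):
    # Per-token page index and offset within that page.
    page_idx = cache_position // page_size
    page_offset = cache_position % page_size
    # Broadcast to every batch element, then shift each batch into its own page range.
    page_idx = page_idx.unsqueeze(0).expand(max_batch_size, -1)
    page_offset = page_offset.unsqueeze(0).expand(max_batch_size, -1)
    batch_offset = torch.arange(max_batch_size) * max_pages_per_batch
    return page_idx + batch_offset.unsqueeze(1), page_offset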
@@ -27,8 +27,12 @@ try:
    from flash_attn import flash_attn_func
except:
    pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
try:
    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
    from ktransformers.operators.triton_attention_prefill import context_attention_fwd
except:
    Warning("triton not found, if you are using npu, ignore this.")

import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
@@ -1,5 +1,8 @@
import torch
import flashinfer
try:
    import flashinfer
except:
    Warning("flashinfer not found, if you are using npu, ignore this.")
import gc
try:
    from flash_attn import flash_attn_with_kvcache
@@ -5,7 +5,11 @@ Version : 0.2.3
'''
import torch
import os
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

try:
    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
except:
    Warning("triton not found, if you are using npu, ignore this.")

flashinfer_enabled = False
@@ -14,7 +14,15 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import ctypes
import torch
from torch import Tensor, nn
if not torch.xpu.is_available():

try:
    import torch_npu

    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False

if not torch.xpu.is_available() and not use_torch_npu:
    import KTransformersOps
    import vLLMMarlin
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
@@ -16,6 +16,7 @@ from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
from ktransformers.util import utils

def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
    for name, child in module._modules.items():

@@ -114,7 +115,7 @@ def translate_model_config(model_config: PretrainedConfig):
    return model_config


def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0", q4_gguf_path=""):
    with open(rule_file, 'r', encoding='utf-8') as f:
        rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)

@@ -123,15 +124,29 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo

    model_config = translate_model_config(model_config)

    weights_loader = ModelLoaderFactory.create_loader(gguf_path)
    with torch.device("meta"):
        inject(module, optimize_config, model_config, weights_loader)
    # pre load lm_head because its big inter result
    load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
    load_weights(module, weights_loader, device=default_device)
    module.gguf_loader = weights_loader
    if q4_gguf_path:
        q4_gguf_loader = GGUFLoader(q4_gguf_path)
        utils.Q4_GGUF_LODER = q4_gguf_loader
        gguf_loader = GGUFLoader(gguf_path, getattr(model_config, "quantize", None))
        with torch.device("meta"):
            inject(module, optimize_config, model_config, gguf_loader)
        # pre load lm_head because its big inter result
        load_weights(module.lm_head, gguf_loader, "lm_head.")
        load_weights(module, gguf_loader)
        module.gguf_loader = gguf_loader

    else:
        weights_loader = ModelLoaderFactory.create_loader(gguf_path)
        with torch.device("meta"):
            inject(module, optimize_config, model_config, weights_loader)
        # pre load lm_head because its big inter result
        load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device)
        load_weights(module, weights_loader, device=default_device)
        module.gguf_loader = weights_loader
    del_meta(module)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.xpu.is_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()
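A short sketch of calling the extended entry point with the new optional argument; model, config and every path below are placeholders following the local_chat_npu flow above, and when q4_gguf_path is empty the original ModelLoaderFactory branch runs unchanged.

from ktransformers.optimize.optimize import optimize_and_load_gguf

# 'model' and 'config' are assumed to come from the usual meta-device construction.
optimize_and_load_gguf(
    model,
    "ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat.yaml",
    "/path/to/deepseek-v3-gguf",
    config,
    q4_gguf_path="/path/to/deepseek-v3-q4-gguf",  # second GGUF; stored in utils.Q4_GGUF_LODER
)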
@@ -0,0 +1,76 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "npu:0"
      prefill_device: "npu:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "npu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "npu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
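This rule file (presumably the npu/DeepSeek-V3-Chat.yaml referenced by default_optimize_rules above) is consumed by inject in ktransformers.optimize.optimize: a module is replaced when its dotted name matches the match.name regular expression and, if present, its class matches match.class. A small illustrative sketch of that matching idea, not the project's actual implementation:

import re

def rule_matches(rule: dict, module_name: str, module_cls_path: str) -> bool:
    match = rule.get("match", {})
    name_pattern = match.get("name")
    cls_path = match.get("class")
    if name_pattern is not None and re.search(name_pattern, module_name) is None:
        return False
    if cls_path is not None and cls_path != module_cls_path:
        return False
    return True

rule = {"match": {"name": r"^model\.layers\..*\.mlp\.experts$"}}
print(rule_matches(rule, "model.layers.3.mlp.experts", "torch.nn.ModuleList"))  # True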
@@ -22,6 +22,10 @@ class ArgumentParser:
            "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
        )
        parser.add_argument("--architectures", type=str, default=self.cfg.model_name)

        parser.add_argument("--tp", type=int, default=1)
        parser.add_argument("--q4_gguf_path", type=str, default=None)

        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
@@ -8,6 +8,7 @@ class ConfigArgs(BaseModel):
    model_dir: Optional[str] = Field(..., description="Path to model directory")
    optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file")
    gguf_path: Optional[str] = Field(None, description="Path of your gguf file")
    tp: int = Field(None, description="tp size")

    class Config:
        protected_namespaces = ()
|
@ -1,4 +1,19 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
|
||||
from ktransformers.util.utils import get_device, get_all_used_cuda_device
|
||||
from ktransformers.util import utils
|
||||
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
import os
|
||||
|
||||
|
||||
from typing import Optional, List
|
||||
import asyncio
|
||||
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
||||
|
@ -19,6 +34,9 @@ from typing import Optional
|
|||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||
from ktransformers.server.schemas.endpoints.chat import RawUsage
|
||||
|
||||
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
class KTransformersThreadContext(TransformersThreadContext):
|
||||
|
@ -26,8 +44,15 @@ class KTransformersThreadContext(TransformersThreadContext):
|
|||
|
||||
|
||||
class KTransformersInterface(TransformersInterface):
|
||||
def __init__(self, args: ConfigArgs = default_args):
|
||||
self.args = args
|
||||
def __init__(self, args: ConfigArgs = default_args, input_args=None):
|
||||
if use_torch_npu:
|
||||
self.args = input_args
|
||||
self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)
|
||||
if utils.CUR_DEVICE is None:
|
||||
utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
|
||||
self.args.device = utils.CUR_DEVICE
|
||||
else:
|
||||
self.args = args
|
||||
torch.set_grad_enabled(False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
||||
|
@ -47,7 +72,10 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
with torch.device("meta"):
|
||||
self.model = custom_models[config.architectures[0]](config)
|
||||
if default_args.optimize_config_path is None:
|
||||
|
||||
if use_torch_npu and input_args.optimize_config_path is not None:
|
||||
optimize_config_path = input_args.optimize_config_path
|
||||
elif default_args.optimize_config_path is None:
|
||||
optimize_config_path = default_optimize_rules[config.architectures[0]]
|
||||
else:
|
||||
optimize_config_path = args.optimize_config_path
|
||||
|
@ -60,7 +88,14 @@ class KTransformersInterface(TransformersInterface):
|
|||
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
|
||||
" belong to current model):"
|
||||
)
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||
|
||||
if use_torch_npu:
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)
|
||||
# absorb the attention weights in advance
|
||||
get_absort_weight(self.model, config)
|
||||
self.model.eval()
|
||||
else:
|
||||
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
|
||||
self.model.generation_config = generation_config
|
||||
self.device_map = self.model.gguf_loader.tensor_device_map
|
||||
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
|
||||
|
@ -77,9 +112,92 @@ class KTransformersInterface(TransformersInterface):
|
|||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
if use_torch_npu:
|
||||
self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype=torch.float16, device=self.args.device)
|
||||
self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype=torch.int32, device=self.args.device)
|
||||
self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype=torch.float16, device=self.args.device)
|
||||
self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device=self.args.device)
|
||||
self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device=self.args.device)
|
||||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
|
||||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
def decode_logits_to_token(self, logits: torch.Tensor):
|
||||
if self.model.generation_config.do_sample:
|
||||
logits = logits / self.temperature
|
||||
torch.manual_seed(0)
|
||||
probs = logits.view(1, self.model.config.vocab_size)
|
||||
sm = nn.Softmax(dim=-1)
|
||||
probs = sm(probs).half().npu()
|
||||
next_token = self.next_token_fake
|
||||
torch_npu._npu_topk_topp_sampling(probs, self.top_k, self.top_p, next_token, self.next_token_probs)
|
||||
last = next_token.squeeze(-1)
|
||||
else:
|
||||
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
_, last = torch.topk(probs, k=1, dim=-1)
|
||||
last = last.item()
|
||||
self.ever_generated_ids.add(last)
|
||||
return last
|
||||
|
||||
def decode_one_tokens_npu(self):
|
||||
global warm_uped
|
||||
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
torch.cuda.set_device(torch_device)
|
||||
if warm_uped and self.args.use_cuda_graph:
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner
|
||||
if check_runner(self.args.device):
|
||||
npu_graph_runner = get_or_create_runner(self.args.device)
|
||||
npu_graph_runner.init(self.args.batch_size, self.seq_length)
|
||||
self.cuda_graph_runner = npu_graph_runner
|
||||
utils._USE_NPU_GRAPH = True
|
||||
self.cuda_graph_runner.capture(
|
||||
self.model,
|
||||
self.current_ids,
|
||||
self.active_cache_position.unsqueeze(0),
|
||||
self.active_cache_position,
|
||||
self.cache,
|
||||
main_device=self.args.device,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if hasattr(self, "cuda_graph_runner"):
|
||||
inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(self.args.device)
|
||||
logits = self.cuda_graph_runner(
|
||||
inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position
|
||||
)
|
||||
self.cache.change_seq_length(1)
|
||||
torch.cuda.synchronize()
|
||||
logits = logits[0, -1, :]
|
||||
return self.decode_logits_to_token(logits)
|
||||
|
||||
if self.args.use_cuda_graph:
|
||||
warm_uped = True
|
||||
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
self.current_ids.to(torch_device),
|
||||
cache_position=self.active_cache_position,
|
||||
past_key_values=self.cache,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||
self.cache.change_seq_length(1)
|
||||
logits = logits[0, -1, :]
|
||||
|
||||
return self.decode_logits_to_token(logits)
|
||||
|
||||
def decode_one_tokens(self):
|
||||
if use_torch_npu:
|
||||
return self.decode_one_tokens_npu()
|
||||
|
||||
global warm_uped
|
||||
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
|
@ -127,9 +245,145 @@ class KTransformersInterface(TransformersInterface):
|
|||
return self.logits_to_token(logits)
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
def prefill_npu(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
if(input_ids_length >= self.args.cache_lens):
|
||||
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
|
||||
self.seq_length = input_ids_length
|
||||
return
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||
device = "cuda:0" if device == "cuda" else device
|
||||
device = self.args.device
|
||||
if is_new:
|
||||
self.ever_generated_ids.clear()
|
||||
same_prefix = 0
|
||||
flat_input_ids = input_ids.flatten()
|
||||
|
||||
if getattr(self, 'generated_ids', None) is None:
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size,
|
||||
input_ids.shape[-1] + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
self.seq_length = 1
|
||||
|
||||
# flat_prev_ids = self.generated_ids.flatten()
|
||||
# for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
|
||||
# if flat_input_ids[i] == flat_prev_ids[i]:
|
||||
# same_prefix += 1
|
||||
# else:
|
||||
# break
|
||||
|
||||
logger.debug(f"same prefix len: {same_prefix}")
|
||||
self.cache.remove_suffix(same_prefix)
|
||||
self.seq_length = same_prefix
|
||||
self.cache.position[0] = same_prefix
|
||||
self.generated_ids = self.generated_ids[..., :same_prefix]
|
||||
input_ids = input_ids[..., same_prefix:]
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
|
||||
self.ever_generated_ids.clear()
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
else:
|
||||
logger.warning(f"seq_length bigger than cache_lens, killed")
|
||||
exit(0)
|
||||
|
||||
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
|
||||
self.cache.position[0] = self.seq_length + 1
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
|
||||
if not (type(self) is TransformersInterface):
|
||||
input_ids = input_ids.to("cpu")
|
||||
|
||||
def chunk_prefill(input_ids, cache_position):
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
torch.cuda.set_device(device)
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.need_plan_all()
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
past_key_values=self.cache,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
return logits
|
||||
|
||||
logits = None
|
||||
def prefill_wrapper(prof=None):
|
||||
nonlocal logits
|
||||
chunk_start = 0
|
||||
while chunk_start < input_ids_length:
|
||||
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
|
||||
if self.cache != None:
|
||||
self.cache.cur_idx = cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
|
||||
chunk_start += self.args.chunk_size
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
if prof is not None:
|
||||
prof.stop()
|
||||
if logits is None:
|
||||
raise ValueError('logits cannot be None')
|
||||
|
||||
|
||||
global WARM_UP_SKIP_CNT
|
||||
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
|
||||
if prof_prefill == "1":
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
|
||||
)
|
||||
with torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU
|
||||
],
|
||||
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=False,
|
||||
with_flops=False,
|
||||
with_modules=False,
|
||||
experimental_config=experimental_config) as prof:
|
||||
prefill_wrapper(prof)
|
||||
else:
|
||||
prefill_wrapper()
|
||||
|
||||
if flashinfer_enabled:
|
||||
MLAWrapperSingleton.reset_buffer()
|
||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
|
||||
|
||||
if use_torch_npu:
|
||||
yield from self.prefill_npu(input_ids, is_new, temperature, top_p, max_tokens, max_completion_tokens)
return
|
||||
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
if max_tokens is not None:
|
||||
max_completion_tokens = max_tokens
|
||||
|
@ -144,6 +398,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||
device = "cuda:0" if device == "cuda" else device
|
||||
if use_torch_npu:
|
||||
device = self.args.device
|
||||
|
||||
if is_new:
|
||||
self.ever_generated_ids.clear()
|
||||
|
@ -159,16 +415,19 @@ class KTransformersInterface(TransformersInterface):
|
|||
)
|
||||
self.seq_length = 1
|
||||
|
||||
flat_prev_ids = self.generated_ids.flatten()
|
||||
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
|
||||
if flat_input_ids[i] == flat_prev_ids[i]:
|
||||
same_prefix += 1
|
||||
else:
|
||||
break
|
||||
if not use_torch_npu:
|
||||
flat_prev_ids = self.generated_ids.flatten()
|
||||
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
|
||||
if flat_input_ids[i] == flat_prev_ids[i]:
|
||||
same_prefix += 1
|
||||
else:
|
||||
break
|
||||
|
||||
logger.debug(f"same prefix len: {same_prefix}")
|
||||
self.cache.remove_suffix(same_prefix)
|
||||
self.seq_length = same_prefix
|
||||
if use_torch_npu:
|
||||
self.cache.position[0] = same_prefix
|
||||
self.generated_ids = self.generated_ids[..., :same_prefix]
|
||||
input_ids = input_ids[..., same_prefix:]
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
|
@ -193,6 +452,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
|
||||
if use_torch_npu:
|
||||
self.cache.position[0] = self.seq_length + 1
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
|
||||
if not (type(self) is TransformersInterface):
|
||||
|
@ -248,4 +509,18 @@ class KTransformersInterface(TransformersInterface):
|
|||
decode_time = self.profiler.get_timer_sec('decode'),
|
||||
prefill_count = self.profiler.get_counter('prefill'),
|
||||
decode_count = self.profiler.get_counter('decode'),
|
||||
)
|
||||
)
|
||||
|
||||
def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
async def run_async():
|
||||
result = []
|
||||
async for chunk in self.inference(local_messages, thread_id, temperature, top_p):
|
||||
pass
|
||||
return ""
|
||||
return loop.run_until_complete(run_async())
|
||||
finally:
|
||||
loop.close()
|
|
@ -32,6 +32,20 @@ from ktransformers.server.config.log import logger
|
|||
from ..args import ConfigArgs, default_args
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||
|
||||
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
from ktransformers.util import utils
|
||||
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
|
||||
class TextStreamer:
|
||||
|
||||
|
@ -191,11 +205,19 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
# input_ids = self.tokenizer.apply_chat_template(
|
||||
# new_messages, return_tensors="pt", add_generation_prompt=True
|
||||
# ).to(self.args.device)
|
||||
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
|
||||
# drop <think> token in chat template
|
||||
if input_str.endswith('<think>\n'):
|
||||
input_str = input_str[:-len('<think>\n')]
|
||||
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
|
||||
|
||||
if not use_torch_npu:
|
||||
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
|
||||
# drop <think> token in chat template
|
||||
if input_str.endswith('<think>\n'):
|
||||
input_str = input_str[:-len('<think>\n')]
|
||||
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
|
||||
else:
|
||||
logger.debug(f"new_messages: {new_messages}")
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
new_messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
x = self.generated_ids[:,:self.seq_length]
|
||||
y = input_ids[:,:self.seq_length]
|
||||
|
@ -212,6 +234,8 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
def append_new_tokens(self, new_tokens: int) -> Optional[str]:
|
||||
self.generated_ids[0, self.seq_length] = new_tokens
|
||||
self.seq_length += 1
|
||||
if use_torch_npu:
|
||||
self.cache.position[0] = self.seq_length
|
||||
return self.streamer.put(new_tokens)
|
||||
|
||||
@staticmethod
|
||||
|
@ -273,14 +297,21 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
top_p = self.model.generation_config.top_p
|
||||
if top_p == 0:
|
||||
top_p = 0.0001
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
top_k=self.args.top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||
)
|
||||
|
||||
if use_torch_npu:
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, do_sample=True,
|
||||
top_p=top_p, temperature=temperature
|
||||
)
|
||||
else:
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
top_k=self.args.top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||
)
|
||||
self.inputs = inputs
|
||||
|
||||
self.logits_warper = self.tf_logits_warper(generation_config)
|
||||
|
@ -372,7 +403,10 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
|
||||
device = input_ids.device
|
||||
if use_torch_npu:
|
||||
device = self.args.device
|
||||
else:
|
||||
device = input_ids.device
|
||||
if not (type(self) is TransformersInterface):
|
||||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
|
@ -420,7 +454,12 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
else: # for's else, if output get max new tokens
|
||||
yield self.streamer.end(), None
|
||||
yield "", "length"
|
||||
|
||||
|
||||
if use_torch_npu and self.args.use_cuda_graph:
|
||||
utils._USE_NPU_GRAPH = False
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(self.args.device)
|
||||
npu_graph_runner.destroy()
|
||||
|
||||
|
||||
def check_is_new(self, thread_id: str):
|
||||
|
@ -436,7 +475,87 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
self.last_request_id = thread_id
|
||||
return True
|
||||
|
||||
|
||||
async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
self.streamer.reset()
|
||||
self.profiler.create_and_start_timer("tokenize")
|
||||
rank = torch.distributed.get_rank()
|
||||
tp_size = utils.get_tensor_parallel_size()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if isinstance(local_messages, List):
|
||||
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
|
||||
elif isinstance(local_messages, str):
|
||||
#local_messages = local_messages[0]['content']
|
||||
input_ids = self.tokenize_prompt(local_messages)
|
||||
#input_ids = torch.tensor([[6366]], device=input_ids.device)
|
||||
else:
|
||||
raise ValueError("local_messages should be List or str")
|
||||
|
||||
if tp_size == world_size and tp_size > 1:
|
||||
torch.distributed.barrier()
|
||||
input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
|
||||
all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
|
||||
dist.all_gather(all_input_sizes, input_size)
|
||||
|
||||
max_input_size = max([size.item() for size in all_input_sizes])
|
||||
padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
|
||||
padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
|
||||
|
||||
all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
|
||||
dist.all_gather(all_padded_inputs, padded_input_ids)
|
||||
|
||||
original_size = all_input_sizes[0].item()
|
||||
input_ids = all_padded_inputs[0][:, :original_size]
|
||||
|
||||
if Config().user_force_think:
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
|
||||
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
|
||||
input_ids = torch.cat(
|
||||
[input_ids, token_thinks], dim=1
|
||||
)
|
||||
|
||||
self.profiler.pause_timer("tokenize")
|
||||
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
|
||||
if Config().user_force_think:
|
||||
think = '<think>\n'
|
||||
if tp_size == world_size and rank != 0:
|
||||
pass
|
||||
else:
|
||||
print(think, end="",flush=True)
|
||||
yield think, None
|
||||
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
|
||||
# output think token after prefill done
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
yield t, None
|
||||
self.profiler.pause_timer("prefill")
|
||||
|
||||
self.profiler.create_and_start_timer("decode")
|
||||
for t, finish_reason in self.generate():
|
||||
if t is not None:
|
||||
if tp_size == world_size and rank != 0:
|
||||
pass
|
||||
else:
|
||||
print(t, end="",flush=True)
|
||||
yield t, finish_reason
|
||||
|
||||
if tp_size == world_size and rank != 0:
|
||||
pass
|
||||
else:
|
||||
self.profiler.pause_timer("decode")
|
||||
self.report_last_time_performance()
|
||||
|
||||
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
|
||||
|
||||
if use_torch_npu:
|
||||
async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
|
||||
yield tok
|
||||
return
|
||||
|
||||
|
||||
self.streamer.reset()
|
||||
self.profiler.create_and_start_timer("tokenize")
|
||||
if isinstance(local_messages, List):
|
||||
|
|
|
@ -9,7 +9,7 @@ project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ktransformers.server.args import ArgumentParser
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
|
||||
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -17,6 +17,21 @@ from ktransformers.server.api import router, post_db_creation_operations
|
|||
from ktransformers.server.utils.sql_utils import Base, SQLUtil
|
||||
from ktransformers.server.config.log import logger
|
||||
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
import torch.distributed
|
||||
import subprocess
|
||||
import tempfile
|
||||
import atexit
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
from ktransformers.util import utils
|
||||
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
||||
def mount_app_routes(mount_app: FastAPI):
|
||||
sql_util = SQLUtil()
|
||||
|
@ -100,6 +115,77 @@ def custom_openapi(app):
|
|||
return app.openapi_schema
|
||||
|
||||
|
||||
def main_npu():
|
||||
torch.npu.config.allow_internal_format = False
|
||||
cfg = Config()
|
||||
|
||||
arg_parser = ArgumentParser(cfg)
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
utils.USE_NPU_GRAPH = args.use_cuda_graph
|
||||
new_chunk_size = min(max(args.chunk_size, 512), utils._MAX_CHUNK_SIZE)
|
||||
if new_chunk_size != args.chunk_size:
|
||||
args.chunk_size = new_chunk_size
|
||||
print(f'[WARN] Chunk size reset to legal value between [512, {utils._MAX_CHUNK_SIZE}] which is {args.chunk_size}.')
|
||||
|
||||
if args.backend_type == "balance_serve":
|
||||
import pickle
|
||||
def cleanup():
|
||||
if sched_process.poll() is None:
|
||||
sched_process.terminate()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
pickle.dump(args, temp_file)
|
||||
temp_file_path = temp_file.name
|
||||
current_file = __file__
|
||||
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
|
||||
target_file = os.path.normpath(target_file)
|
||||
log_path = os.path.join(args.log_dir, "rpc.log")
|
||||
log = open(log_path, "a")
|
||||
sched_process = subprocess.Popen(
|
||||
["python3", target_file, "--config", temp_file_path],
|
||||
stdout=log,
|
||||
stderr=log
|
||||
)
|
||||
print("sched_rpc started with PID:", sched_process.pid)
|
||||
atexit.register(cleanup)
|
||||
create_interface(config=cfg, default_args=cfg, input_args=args)
|
||||
args.port += torch.distributed.get_rank()
|
||||
tp_size = utils.get_tensor_parallel_size()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if tp_size == world_size and tp_size > 1:
|
||||
if torch.distributed.get_rank() == 0:
|
||||
app = create_app()
|
||||
custom_openapi(app)
|
||||
run_api(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
context = get_thread_context_manager()
|
||||
id = str(uuid4())
|
||||
context.interface.sync_inference("", id)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
finally:
|
||||
pass
|
||||
else:
|
||||
app = create_app()
|
||||
custom_openapi(app)
|
||||
|
||||
run_api(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
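The balance_serve branch above launches the scheduler RPC as a separate process: the parsed args are pickled to a temp file, sched_rpc.py is started with subprocess.Popen, its output is appended to rpc.log, and an atexit hook terminates it when the server exits. A stripped-down sketch of that launch-and-cleanup pattern (the script path and log file name here are illustrative, not the repository's actual entry points):

import atexit
import os
import pickle
import subprocess
import tempfile

def launch_background_worker(args, script_path: str, log_dir: str) -> subprocess.Popen:
    # Hand the full argument object to the child via a pickled temp file.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        pickle.dump(args, temp_file)
        config_path = temp_file.name

    log = open(os.path.join(log_dir, "worker.log"), "a")
    proc = subprocess.Popen(
        ["python3", script_path, "--config", config_path],
        stdout=log,
        stderr=log,
    )

    def cleanup():
        # Only terminate if the child is still running.
        if proc.poll() is None:
            proc.terminate()
        log.close()

    atexit.register(cleanup)
    return proc

# Usage (illustrative): proc = launch_background_worker(args, "worker.py", "./logs")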
|
||||
|
||||
def main():
|
||||
cfg = Config()
|
||||
|
||||
|
@ -119,4 +205,7 @@ def main():
|
|||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
if use_torch_npu:
|
||||
main_npu()
|
||||
else:
|
||||
main()
|
||||
|
|
|
@ -16,7 +16,7 @@ from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
|
|||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
|
||||
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
|
||||
|
||||
def create_interface(config: Config, default_args: ConfigArgs):
|
||||
def create_interface(config: Config, default_args: ConfigArgs, input_args=None):
|
||||
if config.backend_type=='transformers':
|
||||
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
|
||||
elif config.backend_type == 'exllamav2':
|
||||
|
@ -27,7 +27,12 @@ def create_interface(config: Config, default_args: ConfigArgs):
|
|||
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
|
||||
else:
|
||||
raise NotImplementedError(f'{config.backend_type} not implemented')
|
||||
GlobalInterface.interface = BackendInterface(default_args)
|
||||
|
||||
if config.backend_type == 'ktransformers':
|
||||
GlobalInterface.interface = BackendInterface(default_args, input_args)
|
||||
else:
|
||||
GlobalInterface.interface = BackendInterface(default_args)
|
||||
|
||||
GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)
|
||||
|
||||
class GlobalContextManager:
|
||||
|
|
ktransformers/util/ascend/ascend_utils.py (new file, 210 lines)
|
@ -0,0 +1,210 @@
|
|||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import warnings
try:
import torch_npu
except ImportError:
warnings.warn("torch_npu not found, please install torch_npu for NPU support.")
|
||||
import torch.distributed as dist
|
||||
|
||||
_DATA_PARALLEL_SIZE = 0
|
||||
_TENSOR_PARALLEL_SIZE = 0
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
_TENSOR_PARALLEL_RANKS = None
|
||||
_TENSOR_PARALLEL_GROUP = None
|
||||
_DATA_PARALLEL_GROUP_GLOO = None
|
||||
_DATA_PARALLEL_RANKS = None
|
||||
_GLOBAL_GROUP = None
|
||||
_LM_HEAD_GROUP = None
|
||||
|
||||
|
||||
def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):
|
||||
global _DATA_PARALLEL_SIZE
|
||||
global _DATA_PARALLEL_GROUP
|
||||
global _DATA_PARALLEL_RANKS
|
||||
global _TENSOR_PARALLEL_SIZE
|
||||
global _TENSOR_PARALLEL_RANKS
|
||||
global _TENSOR_PARALLEL_GROUP
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
local_rank = int(os.getenv("LOCAL_RANK", '0'))
|
||||
world_size = int(os.getenv("WORLD_SIZE", '1'))
|
||||
torch_npu.npu.set_device(local_rank)
|
||||
tp_size = tp
|
||||
dp_size = world_size // tp_size
|
||||
_DATA_PARALLEL_SIZE = dp_size
|
||||
_TENSOR_PARALLEL_SIZE = tp_size
|
||||
|
||||
torch.set_num_threads(8)
|
||||
timeout = timedelta(minutes=distributed_timeout_minutes)
|
||||
print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}")
|
||||
torch.distributed.init_process_group(
|
||||
backend='hccl',
|
||||
world_size=world_size, rank=local_rank
|
||||
)
|
||||
print(f"init process group success ------rank is {local_rank}, world_size is {world_size}")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
nccl_comm_cfgs = {}
|
||||
|
||||
for dp_group_id in range(tp_size):
|
||||
ranks = list(range(dp_group_id, world_size, tp_size))
|
||||
dp_group = torch.distributed.new_group(
|
||||
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
|
||||
)
|
||||
if rank in ranks:
|
||||
global _DATA_PARALLEL_GROUP
|
||||
_DATA_PARALLEL_GROUP = dp_group
|
||||
_DATA_PARALLEL_RANKS = ranks
|
||||
|
||||
for tp_group_id in range(dp_size):
|
||||
start_rank = tp_group_id * tp_size
|
||||
end_rank = (tp_group_id + 1) * tp_size
|
||||
ranks = list(range(start_rank, end_rank))
|
||||
tp_group = torch.distributed.new_group(
|
||||
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
|
||||
)
|
||||
if rank in ranks:
|
||||
global _TENSOR_PARALLEL_GROUP
|
||||
_TENSOR_PARALLEL_GROUP = tp_group
|
||||
_TENSOR_PARALLEL_RANKS = ranks
|
||||
|
||||
torch.manual_seed(1)
|
||||
return local_rank, world_size
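The loop structure above builds the data-parallel groups as strided rank sets and the tensor-parallel groups as contiguous blocks. A quick way to see the resulting layout without any distributed backend (pure arithmetic, same indexing as the code above; world_size=8 and tp=4 are example values):

world_size, tp_size = 8, 4
dp_size = world_size // tp_size

# Data-parallel groups: ranks that hold the same shard, strided by tp_size.
dp_groups = [list(range(g, world_size, tp_size)) for g in range(tp_size)]
# Tensor-parallel groups: contiguous blocks of tp_size ranks.
tp_groups = [list(range(g * tp_size, (g + 1) * tp_size)) for g in range(dp_size)]

print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]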
|
||||
|
||||
|
||||
def get_tensor_parallel_size():
|
||||
assert _TENSOR_PARALLEL_SIZE is not None, "tensor parallel size is not set"
|
||||
return _TENSOR_PARALLEL_SIZE
|
||||
|
||||
|
||||
def get_tensor_parallel_group():
|
||||
assert _TENSOR_PARALLEL_GROUP is not None, "tensor parallel group is not initialized"
|
||||
return _TENSOR_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_tensor_parallel_ranks():
|
||||
assert _TENSOR_PARALLEL_RANKS is not None, "tensor parallel ranks is not initialized"
|
||||
return _TENSOR_PARALLEL_RANKS
|
||||
|
||||
|
||||
def get_data_parallel_size():
|
||||
assert _DATA_PARALLEL_SIZE is not None, "data parallel size is not initialized"
|
||||
return _DATA_PARALLEL_SIZE
|
||||
|
||||
|
||||
def get_data_parallel_gloo():
|
||||
assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel gloo group is not initialized"
|
||||
return _DATA_PARALLEL_GROUP_GLOO
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
|
||||
return _DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_data_parallel_ranks():
|
||||
assert _DATA_PARALLEL_RANKS is not None, "data parallel ranks is not initialized"
|
||||
return _DATA_PARALLEL_RANKS
|
||||
|
||||
|
||||
def get_global_group():
|
||||
assert _GLOBAL_GROUP is not None, "global group is not initialized"
|
||||
return _GLOBAL_GROUP
|
||||
|
||||
|
||||
def get_nccl_options(pg_name, nccl_comm_cfgs):
|
||||
if pg_name in nccl_comm_cfgs:
|
||||
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
|
||||
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
|
||||
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
|
||||
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
|
||||
return nccl_options
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
|
||||
translate_col_cut_tensors = ["ffn_down", "attn_output"]
|
||||
translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"]
|
||||
translate_lm_cut_tensor = ["output"]
|
||||
tp = get_tensor_parallel_size()
|
||||
if tp == 1 or weights.shape == torch.Size([1]):
|
||||
return weights
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
rank %= tp
|
||||
assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}"
|
||||
if any(t in name for t in translate_col_cut_tensors):
|
||||
if weights.dim() == 1:
|
||||
return weights
|
||||
dim = weights.shape[-1]
|
||||
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
|
||||
chunk_size = dim // tp
|
||||
output_weights = weights[:, rank * chunk_size: (rank + 1) * chunk_size]
|
||||
return output_weights
|
||||
elif any(t in name for t in translate_row_cut_tensors):
|
||||
dim = weights.shape[0]
|
||||
assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
|
||||
chunk_size = dim // tp
|
||||
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
|
||||
return output_weights
|
||||
elif (tp > 1) and (any(t in name for t in translate_lm_cut_tensor)):
|
||||
dim = weights.shape[0]
|
||||
assert dim % tp == 0, f"unexpected division {dim=} {world_size=}"
|
||||
chunk_size = dim // tp
|
||||
output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]
|
||||
return output_weights
|
||||
else:
|
||||
return weights
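get_safetensors_cut_weight shards by tensor name: down/output projections are split along the last (column) dimension, gate/up/q_b projections along the first (row) dimension, and the lm head along rows as well. A small standalone illustration of the two slicing modes (the shapes are made up for the example):

import torch

tp, rank = 4, 1
w = torch.arange(8 * 16).reshape(8, 16)

# Column cut (e.g. ffn_down / attn_output): split the last dimension.
col_chunk = w.shape[-1] // tp
col_shard = w[:, rank * col_chunk:(rank + 1) * col_chunk]   # shape (8, 4)

# Row cut (e.g. ffn_gate / ffn_up / attn_q_b): split the first dimension.
row_chunk = w.shape[0] // tp
row_shard = w[rank * row_chunk:(rank + 1) * row_chunk]      # shape (2, 16)

print(col_shard.shape, row_shard.shape)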
|
||||
|
||||
|
||||
def get_absort_weight(model, config):
|
||||
local_rank = torch.distributed.get_rank()
|
||||
tp = get_tensor_parallel_size()
|
||||
local_rank %= tp
|
||||
tp_heads = config.num_attention_heads // tp
|
||||
for i in range(config.num_hidden_layers):
|
||||
self = model.model.layers[i].self_attn
|
||||
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
||||
kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
|
||||
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
|
||||
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
|
||||
q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
|
||||
out_absorb = out_absorb.transpose(1, 2).contiguous()
|
||||
setattr(self, "q_absorb", q_absorb)
|
||||
setattr(self, "out_absorb", out_absorb)
|
||||
del self.orig_module.kv_b_proj
|
||||
|
||||
|
||||
def allreduce_wrapper(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
orig_output = func(*args, **kwargs)
|
||||
if isinstance(orig_output, tuple):
|
||||
if get_tensor_parallel_size() > 1:
|
||||
org_dtype = orig_output[0].dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
dist.all_reduce(orig_output[0].to(dtype=torch.float16), op=dist.ReduceOp.SUM,
|
||||
group=get_tensor_parallel_group())
|
||||
else:
|
||||
dist.all_reduce(orig_output[0], op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
|
||||
if org_dtype == torch.bfloat16:
|
||||
bf_orig_output = orig_output[0].to(dtype=org_dtype)
|
||||
else:
|
||||
bf_orig_output = orig_output[0]
|
||||
else:
|
||||
bf_orig_output = orig_output[0]
|
||||
return (bf_orig_output,) + orig_output[1:]
|
||||
else:
|
||||
if get_tensor_parallel_size() > 1:
|
||||
org_dtype = orig_output.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
orig_output = orig_output.to(dtype=torch.float16)
|
||||
dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
|
||||
if org_dtype == torch.bfloat16:
|
||||
orig_output = orig_output.to(dtype=org_dtype)
|
||||
return orig_output
|
||||
|
||||
return wrapper
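allreduce_wrapper is meant to be applied to a module's forward so that each tensor-parallel rank's partial output is summed across the TP group; note the bf16 outputs are cast to fp16 around the all_reduce. A hedged usage sketch (the RowParallelLinear class is illustrative; the decorator and import path are the ones introduced in this file):

import torch
from torch import nn
from ktransformers.util.ascend.ascend_utils import allreduce_wrapper

class RowParallelLinear(nn.Module):
    """Each TP rank holds a slice of the weight; outputs are partial sums."""
    def __init__(self, in_features_per_rank: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features_per_rank, out_features, bias=False)

    # Wrapping forward means callers never have to remember the reduction.
    @allreduce_wrapper
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

# With torch.distributed initialized and the TP group set up by
# setup_model_parallel(), RowParallelLinear(...)(x) returns the summed
# (full) output on every rank of the tensor-parallel group.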
|
|
@ -7,10 +7,19 @@ from typing import Sequence
|
|||
import os
|
||||
from enum import IntEnum
|
||||
import torch
|
||||
if not torch.xpu.is_available():
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
||||
if not torch.xpu.is_available() and not use_torch_npu:
|
||||
import KTransformersOps
|
||||
from safetensors import safe_open
|
||||
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
|
||||
if not use_torch_npu:
|
||||
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
|
||||
from ktransformers.util.custom_gguf import *
|
||||
from safetensors.torch import save_file
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -42,6 +51,7 @@ class SafeTensorLoader(ModelLoader):
|
|||
tensor_device_map: dict
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
|
||||
self.__load_tensor_file_map(file_path)
|
||||
|
||||
def __load_tensor_file_map(self, file_path: str):
|
||||
|
@ -84,6 +94,7 @@ class SafeTensorLoader(ModelLoader):
|
|||
# if not found_safetensor:
|
||||
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
|
||||
|
||||
|
||||
def load_tensor(self, key: str, device: str="cpu"):
|
||||
if translate_name_to_gguf(key) in self.tensor_file_map:
|
||||
key = translate_name_to_gguf(key)
|
||||
|
@ -96,6 +107,7 @@ class SafeTensorLoader(ModelLoader):
|
|||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
return tensor.to(device)
|
||||
|
||||
def load_experts(self, key: str, device: str="cpu"):
|
||||
|
@ -252,20 +264,57 @@ class SafeTensorLoader(ModelLoader):
|
|||
def has_tensor(self, name: str):
|
||||
return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
|
||||
|
||||
|
||||
class W8A8SafeTensorLoader(SafeTensorLoader):
|
||||
def load_tensor(self, key: str, device: str = "cpu"):
|
||||
if key not in self.tensor_file_map:
|
||||
raise KeyError(f"Key {key} not found in Safetensor files")
|
||||
file = self.tensor_file_map[key]
|
||||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
if "deq_scale" in key:
|
||||
tensor = torch.from_numpy(
|
||||
np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64))
|
||||
if "input_scale" in key:
|
||||
tensor = tensor.to(torch.float16)
|
||||
if "weight_scale" in key or "weight_offset" in key:
|
||||
if "ffn" in key:
|
||||
tensor = tensor.to(torch.float32)
|
||||
else:
|
||||
tensor = tensor.to(torch.float16)
|
||||
if "input_offset" in key:
|
||||
tensor = tensor.to(torch.int8)
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
tensor = tensor.to(torch.float16)
|
||||
return tensor.to(device)
|
||||
|
||||
def load_dequantized_tensor(self, key: str, device: str = "cpu"):
|
||||
tensor = self.load_tensor(key, device)
|
||||
return tensor
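The most opaque line in W8A8SafeTensorLoader is the deq_scale handling: the value is not converted numerically, its raw float32 bit pattern is reinterpreted as int32 and then widened to int64, which is the encoding the downstream dequant kernels consume. A tiny standalone illustration of that reinterpretation (the value 1.0 is arbitrary):

import numpy as np
import torch

scale = torch.tensor([1.0], dtype=torch.float32)

# Reinterpret the float32 bits as int32, then widen to int64.
as_bits = np.frombuffer(scale.numpy().tobytes(), dtype=np.int32).astype(np.int64)
print(as_bits)  # [1065353216] == 0x3F800000, the IEEE-754 encoding of 1.0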
|
||||
|
||||
class GGUFLoader(ModelLoader):
|
||||
tensor_info: dict
|
||||
gguf_path: str
|
||||
tensor_file_map: dict # {tensor_name: tensor_file_path}
|
||||
gguf_file_meta: dict
|
||||
safetensor_loader: SafeTensorLoader
|
||||
def __init__(self, gguf_path: str):
|
||||
def __init__(self, gguf_path: str, quantize: str = None):
|
||||
# Check dir exist
|
||||
if not os.path.exists(gguf_path):
|
||||
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
|
||||
if os.path.isfile(gguf_path):
|
||||
gguf_path = os.path.dirname(gguf_path)
|
||||
|
||||
self.safetensor_loader = None
|
||||
safetensor_loader = SafeTensorLoader(gguf_path)
|
||||
if quantize == "w8a8_dynamic":
|
||||
safetensor_loader = W8A8SafeTensorLoader(gguf_path)
|
||||
else:
|
||||
safetensor_loader = SafeTensorLoader(gguf_path)
|
||||
if safetensor_loader.tensor_file_map:
|
||||
self.safetensor_loader = safetensor_loader
|
||||
return
|
||||
|
||||
self.tensor_info = {}
|
||||
self.gguf_path = gguf_path
|
||||
|
|
ktransformers/util/npu_graph.py (new file, 77 lines)
|
@ -0,0 +1,77 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ktransformers.util.utils import USE_NPU_GRAPH
|
||||
if USE_NPU_GRAPH:
|
||||
CAPTURE_PLUGIN_PATH = os.environ.get("CAPTURE_PLUGIN_PATH")
|
||||
if CAPTURE_PLUGIN_PATH is None:
|
||||
raise RuntimeError("env CAPTURE_PLUGIN_PATH not exist")
|
||||
|
||||
sys.path.append(CAPTURE_PLUGIN_PATH)
|
||||
|
||||
from libgraph_capture import graph_capture_init
|
||||
from libgraph_capture import graph_capture_destroy
|
||||
from libgraph_capture import graph_capture_begin
|
||||
from libgraph_capture import graph_capture_end
|
||||
from libgraph_capture import graph_capture_replay
|
||||
from libgraph_capture import graph_capture_launch_callback
|
||||
|
||||
|
||||
class NpuGraph:
|
||||
def init(self):
|
||||
ret = graph_capture_init()
|
||||
if ret != 0:
|
||||
exit()
|
||||
|
||||
def destroy(self):
|
||||
ret = graph_capture_destroy()
|
||||
if ret != 0:
|
||||
exit()
|
||||
|
||||
def capture_begin(
|
||||
self,
|
||||
stream,
|
||||
capture_error_mode="global"):
|
||||
torch.npu.synchronize()
|
||||
torch.npu.empty_cache()
|
||||
ret = graph_capture_begin(stream, capture_error_mode)
|
||||
if ret != 0:
|
||||
exit()
|
||||
|
||||
def capture_end(
|
||||
self,
|
||||
stream):
|
||||
ret = graph_capture_end(stream)
|
||||
if ret != 0:
|
||||
exit()
|
||||
|
||||
def replay(
|
||||
self,
|
||||
stream):
|
||||
ret = graph_capture_replay(stream)
|
||||
if ret != 0:
|
||||
exit()
|
||||
|
||||
def launch_callback(self, func, data, block, stream):
|
||||
graph_capture_launch_callback(func, data, block, stream)
|
||||
|
||||
|
||||
class graph:
|
||||
def __init__(
|
||||
self,
|
||||
npu_graph: NpuGraph,
|
||||
pool,
|
||||
stream,
|
||||
capture_error_mode: str = "global"):
|
||||
self.npu_graph = npu_graph
|
||||
self.stream = stream.npu_stream
|
||||
|
||||
def __enter__(self):
|
||||
self.npu_graph.capture_begin(self.stream)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.npu_graph.capture_end(self.stream)
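npu_graph.graph mirrors torch.cuda.graph as a context manager: work issued on the given stream between __enter__ and __exit__ is captured instead of executed, and can be replayed later. A hedged usage sketch, assuming utils.USE_NPU_GRAPH has been enabled before this module is imported and the capture plugin pointed to by CAPTURE_PLUGIN_PATH is installed; run_decode_step is a placeholder for the captured workload:

import torch_npu
from ktransformers.util import npu_graph

g = npu_graph.NpuGraph()
g.init()
stream = torch_npu.npu.Stream()

with npu_graph.graph(g, pool=None, stream=stream):
    # Kernels launched on `stream` inside this block are recorded, not executed.
    run_decode_step()  # placeholder

g.replay(stream.npu_stream)  # replays the recorded kernels
g.destroy()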
|
ktransformers/util/npu_graph_runner.py (new file, 218 lines)
|
@ -0,0 +1,218 @@
|
|||
'''
|
||||
Description :
|
||||
Author : Boxin Zhang
|
||||
Version : 0.1.0
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
'''
|
||||
from typing import Dict
|
||||
|
||||
import acl
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
|
||||
import ktransformers.util.npu_graph as npu_graph
|
||||
from ktransformers.util.utils import CUR_DEVICE
|
||||
|
||||
|
||||
class NPUGraphRunner:
|
||||
def __init__(self, deviceId):
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
self.deviceId = deviceId
|
||||
self.enable = False
|
||||
self.debug = False
|
||||
self.input_buffers: Dict[str, torch.Tensor] = {}
|
||||
self.output_buffers: Dict[str, torch.Tensor] = {}
|
||||
self.tid = None
|
||||
self.past_key_value = None
|
||||
|
||||
def init(self, batch_size, seq_length):
|
||||
self.tmp_g = npu_graph.NpuGraph()
|
||||
self.graph = torch.npu.NPUGraph()
|
||||
self.main_stream = torch_npu.npu.Stream(device=self.deviceId)
|
||||
self.update_stream = torch_npu.npu.Stream(device=self.deviceId)
|
||||
self.stream = self.main_stream.npu_stream
|
||||
self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId)
|
||||
self.context, ret = acl.rt.get_context(self.deviceId)
|
||||
if ret != 0:
|
||||
print("get_context failed! ret: " + str(ret))
|
||||
exit(-1)
|
||||
self.exit_flag = False
|
||||
self.handle = []
|
||||
self.ifa_param = []
|
||||
self.event = []
|
||||
self.first_update = True
|
||||
self.workspace = None
|
||||
|
||||
if self.tid is None:
|
||||
def process_callback(args_list):
|
||||
ins = args_list[0]
|
||||
ret = acl.rt.set_context(ins.context)
|
||||
if ret != 0:
|
||||
print("set_context failed! ret: " + str(ret))
|
||||
exit(-1)
|
||||
|
||||
while True:
|
||||
acl.rt.process_report(1)
|
||||
if ins.exit_flag:
|
||||
break
|
||||
|
||||
self.tid, ret = acl.util.start_thread(process_callback, [self])
|
||||
if ret != 0:
|
||||
print("start_thread failed!")
|
||||
exit(-1)
|
||||
|
||||
ret = acl.rt.subscribe_report(self.tid, self.stream)
|
||||
if ret != 0:
|
||||
print("subscribe_report failed!")
|
||||
exit(-1)
|
||||
|
||||
def destroy(self):
|
||||
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Begin -------------\n', end='')
|
||||
self.exit_flag = True
|
||||
ret = acl.rt.unsubscribe_report(self.tid, self.stream)
|
||||
if ret != 0:
|
||||
print("unsubscribe_report failed!")
|
||||
exit(-1)
|
||||
self.enable = False
|
||||
ret = acl.util.stop_thread(self.tid)
|
||||
if ret != 0:
|
||||
print("stop_thread failed!")
|
||||
exit(-1)
|
||||
self.tid = None
|
||||
self.workspace = None
|
||||
self.handle = []
|
||||
self.ifa_param = []
|
||||
self.event = []
|
||||
self.first_update = True
|
||||
del self.graph
|
||||
self.tmp_g.destroy()
|
||||
destroy_runner(self.deviceId)
|
||||
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Destroy Finish -------------\n', end='')
|
||||
|
||||
def capture(
|
||||
self,
|
||||
model,
|
||||
cur_token,
|
||||
position_ids,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
main_device,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Begin -------------\n', end='')
|
||||
self.enable = True
|
||||
self.model = model
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
|
||||
self.seq_length = inputs_embeds.size()[1]
|
||||
self.main_device = main_device
|
||||
with torch.no_grad():
|
||||
with torch.npu.graph(self.graph, stream=self.main_stream):
|
||||
self.logits = model(inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs)[0]
|
||||
|
||||
if past_key_values != None:
|
||||
past_key_values.change_seq_length(-1)
|
||||
|
||||
self.input_buffers = {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
|
||||
self.output_buffers = {"logits": self.logits}
|
||||
print(f'[rank:{torch.distributed.get_rank()}]------------- NPU Graph Capture Finish -------------\n', end='')
|
||||
return
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cache_position,
|
||||
) -> torch.Tensor:
|
||||
def ifa_update_sync(param):
|
||||
with torch.npu.stream(self.update_stream):
|
||||
for i in range(len(self.handle)):
|
||||
if self.first_update is False:
|
||||
q_nope, kvCache, q_pe, kRopeCache, num_heads, \
|
||||
softmax_scale, layer_idx, attn_output, softmax_lse = self.ifa_param[i]
|
||||
torch.npu.graph_task_update_begin(self.update_stream, self.handle[i])
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
kvCache,
|
||||
kvCache,
|
||||
workspace=self.workspace,
|
||||
query_rope=q_pe,
|
||||
key_rope=kRopeCache,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=None,
|
||||
scale=softmax_scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=self.past_key_value.page_table_list[layer_idx],
|
||||
block_size=self.past_key_value.page_size,
|
||||
actual_seq_lengths_kv=self.past_key_value.position,
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(self.update_stream)
|
||||
self.event[i].record(self.update_stream)
|
||||
|
||||
self.ifa_update_tid, ret = acl.util.start_thread(ifa_update_sync, [self])
|
||||
if ret != 0:
|
||||
print("start_thread failed!")
|
||||
exit(-1)
|
||||
|
||||
ret1 = acl.rt.memcpy(self.input_buffers["inputs_embeds"].data_ptr(), inputs_embeds.numel() * 2,
|
||||
inputs_embeds.data_ptr(), inputs_embeds.numel() * 2, 3)
|
||||
ret2 = acl.rt.memcpy(self.input_buffers["position_ids"].data_ptr(), position_ids.numel() * 8,
|
||||
position_ids.data_ptr(), position_ids.numel() * 8, 3)
|
||||
ret3 = acl.rt.memcpy(self.input_buffers["cache_position"].data_ptr(), cache_position.numel() * 8,
|
||||
cache_position.data_ptr(), cache_position.numel() * 8, 3)
|
||||
torch_npu.npu.synchronize()
|
||||
|
||||
with torch_npu.npu.stream(self.main_stream):
|
||||
self.graph.replay()
|
||||
self.first_update = False
|
||||
ret = acl.util.stop_thread(self.ifa_update_tid)
|
||||
if ret != 0:
|
||||
print("stop_thread failed!")
|
||||
exit(-1)
|
||||
else:
|
||||
self.ifa_update_tid = None
|
||||
return self.output_buffers["logits"]
|
||||
|
||||
def launch_callback(self, func, data, block, stream):
|
||||
self.tmp_g.launch_callback(func, data, block, stream)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
|
||||
runner_dict = dict()
|
||||
|
||||
|
||||
def check_runner(deviceId: int):
|
||||
runner = runner_dict.get(deviceId)
|
||||
if runner is None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def destroy_runner(deviceId: int):
|
||||
runner = runner_dict.get(deviceId)
|
||||
if runner is not None:
|
||||
runner_dict[deviceId] = None
|
||||
|
||||
|
||||
def get_or_create_runner(deviceId: int):
|
||||
runner = runner_dict.get(deviceId)
|
||||
|
||||
if runner is None:
|
||||
runner = NPUGraphRunner(deviceId)
|
||||
runner_dict[deviceId] = runner
|
||||
return runner
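Putting the runner pieces together: decode code obtains a per-device runner, initializes it, captures one real decode step, and from then on calls the runner (replay) instead of the model. A condensed sketch of that lifecycle, mirroring the decode path later in this diff; model, the input tensors, and device_id are placeholders (device_id stands for whatever utils.CUR_DEVICE holds):

from ktransformers.util.npu_graph_runner import get_or_create_runner

runner = get_or_create_runner(device_id)
runner.init(batch_size=1, seq_length=1)

# One real decode step is captured; its input/output buffers become fixed.
runner.capture(model, cur_token, position_ids, cache_position,
               past_key_values, main_device=device_id,
               return_dict=False, use_cache=True)

# Later steps copy fresh inputs into the captured buffers and replay.
logits = runner(inputs_embeds, position_ids, cache_position)

runner.destroy()  # unsubscribes the callback thread and frees the graph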
|
|
@ -31,8 +31,35 @@ if not torch.xpu.is_available():
|
|||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
||||
import socket
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch.distributed as dist
|
||||
try:
|
||||
import torch_npu
|
||||
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
|
||||
W8A8_ENABLE = False
|
||||
Q4_GGUF_LODER = None
|
||||
USE_NPU_GRAPH = None
|
||||
WARM_UP_SKIP_CNT = [1, 1]
|
||||
_USE_NPU_GRAPH = False
|
||||
_MAX_DECODE_PROFILE = 3
|
||||
CUR_DEVICE = None
|
||||
_MAX_CHUNK_SIZE = max(int(os.getenv("_MAX_CHUNK_SIZE", 4096)), 512)
|
||||
|
||||
|
||||
def get_use_npu_graph():
|
||||
assert _USE_NPU_GRAPH is not None, "use_npu_graph is not set"
|
||||
return _USE_NPU_GRAPH
|
||||
|
||||
|
||||
def get_free_ports(n: int, continue_prot: list):
|
||||
sockets = []
|
||||
ports = []
|
||||
|
@ -50,6 +77,10 @@ def get_free_ports(n: int, continue_prot: list):
|
|||
return ports
|
||||
|
||||
def get_compute_capability(device:torch.device = None):
|
||||
|
||||
if use_torch_npu:
|
||||
return 0
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if device is None:
|
||||
num_gpus = torch.cuda.device_count()
|
||||
|
@ -97,9 +128,16 @@ def get_all_used_cuda_device(device_map:dict):
|
|||
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
|
||||
if "cpu" in all_device_list:
|
||||
all_device_list.remove("cpu")
|
||||
|
||||
if use_torch_npu:
|
||||
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
|
||||
|
||||
all_device_list = list(all_device_list)
|
||||
return all_device_list
|
||||
|
||||
|
||||
|
||||
# TODO: support NPU
|
||||
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
|
||||
prefix = prefix.replace("orig_module.", "")
|
||||
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
||||
|
@ -109,6 +147,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
|
|||
key = prefix + name
|
||||
translated_key = key
|
||||
|
||||
|
||||
# TODO: Merge all loader.
|
||||
# I know this is ugly but lets do it for now.
|
||||
if isinstance(gguf_loader, SafeTensorLoader):
|
||||
|
@ -120,7 +159,13 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
|
|||
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
|
||||
target_dtype = torch.get_default_dtype()
|
||||
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
|
||||
print(f"loading {translated_key} to {device}")
|
||||
|
||||
|
||||
if use_torch_npu:
|
||||
device = "cpu" if "embd" in translated_key else CUR_DEVICE
|
||||
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
|
||||
else:
|
||||
print(f"loading {translated_key} to {device}")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif torch.xpu.is_available():
|
||||
|
@ -149,6 +194,8 @@ def sync_all_device(all_device_list):
|
|||
torch.cuda.synchronize(device)
|
||||
elif "xpu" in device.lower():
|
||||
torch.xpu.synchronize(device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.synchronize(device)
|
||||
else:
|
||||
raise RuntimeError("The device {} is not available".format(device))
|
||||
|
||||
|
@ -228,20 +275,68 @@ def tf_logits_warper(generation_config):
|
|||
|
||||
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
||||
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
|
||||
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
|
||||
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
batch_size, seq_length = inputs.shape
|
||||
device_map = model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device('model.layers.0.self_attn', device_map)
|
||||
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
|
||||
|
||||
if use_torch_npu:
|
||||
vocabulary_size = model.config.vocab_size
|
||||
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
|
||||
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
|
||||
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
|
||||
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
|
||||
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
|
||||
torch_device = CUR_DEVICE
|
||||
else:
|
||||
torch_device = get_device('model.layers.0.self_attn', device_map)
|
||||
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
|
||||
inputs = inputs.to(torch_device)
|
||||
all_cuda_device = get_all_used_cuda_device(device_map)
|
||||
|
||||
tokens = []
|
||||
|
||||
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
|
||||
if cuda_graph_runner is None:
|
||||
use_cuda_graph = False
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
|
||||
if use_cuda_graph:
|
||||
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
|
||||
else:
|
||||
# custom_stream = torch.cuda.Stream()
|
||||
# torch.cuda.set_device(torch_device)
|
||||
torch_npu.npu.set_device(torch_device)
|
||||
# with torch.cuda.stream(custom_stream):
|
||||
logits=model(inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False, use_cache=True)[0]
|
||||
if past_key_values != None:
|
||||
past_key_values.change_seq_length(1)
|
||||
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
|
||||
for device in all_cuda_device:
|
||||
# torch.cuda.synchronize(device)
|
||||
torch_npu.npu.synchronize(device)
|
||||
if generation_config.do_sample:
|
||||
logits = logits / temperature
|
||||
torch.manual_seed(0)
|
||||
probs = logits.view(batch_size, vocabulary_size)
|
||||
sm = nn.Softmax(dim=-1)
|
||||
probs = sm(probs).half().npu()
|
||||
next_token = next_token_fake
|
||||
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
|
||||
next_token = next_token.squeeze(-1)
|
||||
else:
|
||||
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
||||
next_token = torch.argmax(next_token_scores, dim=-1)
|
||||
return next_token
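The NPU decode path above samples with the fused torch_npu._npu_topk_topp_sampling kernel. For reference, this is what top-k/top-p (nucleus) sampling computes, written in plain PyTorch; a functional sketch for understanding, not the kernel's implementation:

import torch

def topk_topp_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: (batch, vocab), already softmaxed.
    vals, idx = torch.topk(probs, top_k, dim=-1)             # keep the k largest
    cumulative = torch.cumsum(vals, dim=-1)
    vals = vals.masked_fill(cumulative - vals > top_p, 0.0)  # nucleus cut-off
    vals = vals / vals.sum(dim=-1, keepdim=True)             # renormalize
    choice = torch.multinomial(vals, num_samples=1)
    return idx.gather(-1, choice)                            # map back to vocab ids

probs = torch.softmax(torch.randn(1, 32), dim=-1)
print(topk_topp_sample(probs, top_k=8, top_p=0.9))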
|
||||
|
||||
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
|
||||
if use_torch_npu:
|
||||
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
|
||||
if cuda_graph_runner is None:
|
||||
use_cuda_graph = False
|
||||
if use_cuda_graph:
|
||||
|
@ -252,6 +347,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
torch.cuda.set_device(torch_device)
|
||||
elif torch.xpu.is_available():
|
||||
torch.xpu.set_device(torch_device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.set_device(torch_device)
|
||||
else:
|
||||
raise RuntimeError(f"The device: {torch_device} is not available")
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
|
||||
|
@ -279,6 +376,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
|
||||
else:
|
||||
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
|
||||
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
|
||||
MLAWrapperSingleton.need_plan_all()
|
||||
|
@ -288,11 +386,88 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
|
||||
global warm_uped
|
||||
global _USE_NPU_GRAPH
|
||||
if use_cuda_graph:
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
|
||||
npu_graph_runner.init(batch_size, seq_length)
|
||||
with torch_npu.npu.stream(npu_graph_runner.main_stream):
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
|
||||
torch.bfloat16)
|
||||
if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
|
||||
warm_uped = True
|
||||
_USE_NPU_GRAPH = True
|
||||
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
cuda_graph_runner = npu_graph_runner
|
||||
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
|
||||
cache_position, past_key_values, logits_warper, generation_config,
|
||||
use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
|
||||
next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
past_key_values.position[0] += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
npu_graph_runner.destroy()
|
||||
_USE_NPU_GRAPH = False
|
||||
else:
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
|
||||
torch.bfloat16)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
|
||||
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
|
||||
torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
|
||||
next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
past_key_values.position[0] += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
if prof is not None:
|
||||
prof.stop()
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(torch_device)
|
||||
elif torch.xpu.is_available():
|
||||
torch.xpu.set_device(torch_device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.set_device(torch_device)
|
||||
else:
|
||||
raise RuntimeError(f"The device: {torch_device} is not available")
|
||||
with torch.no_grad():
|
||||
|
@ -304,6 +479,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
|
||||
else:
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(None)
|
||||
elif use_torch_npu and static_cache:
|
||||
assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'
|
||||
past_key_values = static_cache
|
||||
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
|
||||
print('[WARN] current StaticCache size exceeded, creating a new StaticCache...')
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
|
||||
)
|
||||
else:
|
||||
past_key_values.reset()
|
||||
elif mode != 'long_context':
|
||||
past_key_values = StaticCache(
|
||||
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
|
||||
|
@ -320,19 +505,67 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
logits_warper = tf_logits_warper(generation_config)
|
||||
|
||||
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
|
||||
if use_torch_npu:
|
||||
past_key_values.position[0] = seq_length + 1
|
||||
|
||||
generated_ids = torch.zeros(
|
||||
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
|
||||
)
|
||||
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
|
||||
start_time = time.time()
|
||||
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
logits = None
|
||||
|
||||
def prefill_wrapper(prof=None):
|
||||
nonlocal logits
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
if prof is not None:
|
||||
prof.stop()
|
||||
if logits is None:
|
||||
raise ValueError('logits cannot be None')
|
||||
|
||||
if use_torch_npu:
|
||||
global WARM_UP_SKIP_CNT
|
||||
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
|
||||
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
|
||||
)
|
||||
with torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU
|
||||
],
|
||||
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=False,
|
||||
with_flops=False,
|
||||
with_modules=False,
|
||||
experimental_config=experimental_config) as prof:
|
||||
prefill_wrapper(prof)
|
||||
else:
|
||||
prefill_wrapper()
|
||||
WARM_UP_SKIP_CNT[0] -= 1
|
||||
else:
|
||||
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
|
||||
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
||||
if generation_config.do_sample:
|
||||
|
@ -348,56 +581,106 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
|
||||
prefill_count = seq_length
|
||||
prefill_time = first_token_time
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
elif not use_torch_npu:
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
|
||||
generated_ids[:, seq_length] = next_token
|
||||
tokens.append(int(next_token))
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
seq_length += 1
|
||||
if use_torch_npu:
|
||||
past_key_values.position += 1
|
||||
|
||||
cuda_graph_runner = None
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
|
||||
global warm_uped
|
||||
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
cuda_graph_runner = CUDAGraphRunner()
|
||||
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
|
||||
if not use_torch_npu:
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
|
||||
global warm_uped
|
||||
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
cuda_graph_runner = CUDAGraphRunner()
|
||||
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
else:
|
||||
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
|
||||
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
|
||||
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
|
||||
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
|
||||
)
|
||||
with torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU
|
||||
],
|
||||
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=False,
|
||||
with_flops=False,
|
||||
with_modules=False,
|
||||
experimental_config=experimental_config) as prof:
|
||||
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
|
||||
else:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
|
||||
WARM_UP_SKIP_CNT[1] -= 1
|
||||
|
||||
|
||||
total_time = time.time() - start_time
|
||||
tokens_generated = len(tokens)
|
||||
tokens_per_second = tokens_generated / total_time
|
||||
|
||||
print("")
|
||||
if not use_torch_npu:
|
||||
print("")
|
||||
|
||||
print(f"prompt eval count: {prefill_count} token(s)")
|
||||
print(f"prompt eval duration: {prefill_time}s")
|
||||
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
|
||||
print(f"eval count: {tokens_generated} token(s)")
|
||||
print(f"eval duration: {total_time}s")
|
||||
print(f"eval rate: {tokens_per_second} tokens/s")
|
||||
else:
|
||||
tp_size = get_tensor_parallel_size()
|
||||
if torch.distributed.get_rank() % tp_size == 0:
|
||||
rank = f"[rank:{torch.distributed.get_rank()}]"
|
||||
msg = f"\n{rank} Eval Time\n"
|
||||
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
|
||||
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
|
||||
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
|
||||
msg += rank + f"eval count: {tokens_generated} token(s)\n"
|
||||
msg += rank + f"eval duration: {total_time:.9f}s\n"
|
||||
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
|
||||
print(msg)
|
||||
|
||||
print(f"prompt eval count: {prefill_count} token(s)")
|
||||
print(f"prompt eval duration: {prefill_time}s")
|
||||
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
|
||||
print(f"eval count: {tokens_generated} token(s)")
|
||||
print(f"eval duration: {total_time}s")
|
||||
print(f"eval rate: {tokens_per_second} tokens/s")
|
||||
|
||||
return tokens
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from safetensors.torch import save_file
|
|||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
SKIP_MTP = True
|
||||
|
||||
def read_safetensor_keys_from_folder(folder_path)->dict:
|
||||
"""
|
||||
:param folder_path: folder path
|
||||
|
@ -36,7 +38,7 @@ def read_safetensor_keys_from_folder(folder_path)->dict:
|
|||
try:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
if "model.layers.61" in key:
|
||||
if SKIP_MTP and "model.layers.61" in key:
|
||||
# skip MTP layer
|
||||
continue
|
||||
# try:
|
||||
|
@ -94,6 +96,28 @@ def combine_tensor_sources(safetensor_path:str, gguf_path:str):
|
|||
|
||||
return target_tensor_map, gguf_loader
|
||||
|
||||
def combine_w8a8_tensor_sources(safetensor_path: str, gguf_path: str):
|
||||
gguf_loader = GGUFLoader(gguf_path)
|
||||
gguf_tensor_file_map = gguf_loader.tensor_file_map
|
||||
safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
|
||||
|
||||
# build a map for the key to the tensor
|
||||
# according to the key, we can get the tensor from the file
|
||||
|
||||
target_tensor_map = {}
|
||||
for key in safetensor_tensor_file_map.keys():
|
||||
# for all experts, we use the gguf tensor
|
||||
if ".mlp.experts." in key and "weight_scale" not in key and "weight_offset" not in key:
|
||||
key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
|
||||
translated_key = translate_name(key)
|
||||
target_tensor_map[key] = gguf_tensor_file_map[translated_key]
|
||||
elif ".mlp.experts." in key and ("weight_scale" not in key or "weight_offset" not in key):
|
||||
continue
|
||||
else:
|
||||
target_tensor_map[key] = safetensor_tensor_file_map[key]
|
||||
|
||||
return target_tensor_map, gguf_loader
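The expert branch above rewrites a per-expert safetensor key into the merged form used on the GGUF side by keeping the first five dot-separated fields and the last two. A quick illustration of that key collapsing on a hypothetical DeepSeek-style key (the exact key string is illustrative):

key = "model.layers.3.mlp.experts.17.down_proj.weight"
collapsed = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
print(collapsed)  # model.layers.3.mlp.experts.down_proj.weight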
|
||||
|
||||
def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
@ -193,6 +217,7 @@ def main():
|
|||
parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
|
||||
parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
|
||||
parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
|
||||
parser.add_argument("--safetensors_format", type=str, help="Safetensors format", default="fp8")
|
||||
|
||||
# print all the arguments
|
||||
print("All the arguments:")
|
||||
|
@ -204,8 +229,18 @@ def main():
|
|||
safetensor_path = args.safetensor_path
|
||||
gguf_path = args.gguf_path
|
||||
output_path = args.output_path
|
||||
safetensors_format = args.safetensors_format
|
||||
|
||||
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
|
||||
match safetensors_format:
|
||||
case "w8a8":
|
||||
global SKIP_MTP
|
||||
SKIP_MTP = False
|
||||
target_tensor_map, gguf_loader = combine_w8a8_tensor_sources(safetensor_path, gguf_path)
|
||||
case "fp8":
|
||||
target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported safetensors format: {safetensor_path}")
|
||||
|
||||
write_combined_tensor(target_tensor_map, output_path, gguf_loader)
|
||||
|
||||
return
|
||||
|
|
setup.py (32 lines changed)
|
@ -673,10 +673,29 @@ if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
|
|||
ext_modules.append(
|
||||
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
|
||||
)
|
||||
|
||||
setup(
|
||||
name=VersionInfo.PACKAGE_NAME,
|
||||
version=VersionInfo().get_package_version(),
|
||||
install_requires=triton_dep,
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=ext_modules
|
||||
)
|
||||
|
||||
|
||||
|
||||
elif torch.xpu.is_available():
|
||||
ext_modules = [
|
||||
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
|
||||
]
|
||||
setup(
|
||||
name=VersionInfo.PACKAGE_NAME,
|
||||
version=VersionInfo().get_package_version(),
|
||||
install_requires=triton_dep,
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=ext_modules
|
||||
)
|
||||
|
||||
elif KTRANSFORMERS_BUILD_NPU:
|
||||
ext_modules = [
|
||||
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
|
||||
|
@ -687,11 +706,10 @@ elif KTRANSFORMERS_BUILD_NPU:
|
|||
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
|
||||
)
|
||||
|
||||
setup(
|
||||
name=VersionInfo.PACKAGE_NAME,
|
||||
version=VersionInfo().get_package_version(),
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=ext_modules
|
||||
)
|
||||
|
||||
setup(
|
||||
name=VersionInfo.PACKAGE_NAME,
|
||||
version=VersionInfo().get_package_version(),
|
||||
install_requires=triton_dep,
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=ext_modules
|
||||
)
|
||||
|
|
third_party/llamafile/iqk_mul_mat.inc (vendored, 4930 lines changed; diff too large, suppressed)
third_party/llamafile/iqk_mul_mat_arm.inc (vendored, new file, 5866 lines; diff too large, suppressed)
third_party/llamafile/iqk_mul_mat_arm80.cpp (vendored, new file, 10 lines)
|
@ -0,0 +1,10 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm80.cpp
|
||||
// Copyright 2024 Iwan Kawrakow.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
#ifdef __aarch64__
|
||||
#define iqk_mul_mat iqk_mul_mat_arm80
|
||||
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm80
|
||||
#include "iqk_mul_mat.inc"
|
||||
#endif // __aarch64__
|
third_party/llamafile/iqk_mul_mat_x86.inc (vendored, new file, 4925 lines; diff too large, suppressed)
third_party/llamafile/sgemm.cpp (vendored, 209 lines changed)
|
@ -1,204 +1,7 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
|
||||
// Copyrigth 2024 Mozilla Foundation.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "sgemm.h"
|
||||
// #include <cosmo.h>
|
||||
// #include <cpuid.h>
|
||||
// #include <libc/sysv/consts/hwcap.h>
|
||||
#include <stdio.h>
|
||||
// #include <sys/auxv.h>
|
||||
#include <cassert>
|
||||
// #include "llamafile.h"
|
||||
|
||||
static const struct GemmFuncs {
|
||||
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
|
||||
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
|
||||
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
|
||||
// typeof(llamafile_sgemm)* sgemm;
|
||||
// typeof(llamafile_mixmul)* mixmul;
|
||||
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
|
||||
GemmFuncs() {
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
// if (X86_HAVE(AVX)) {
|
||||
// if (X86_HAVE(FMA)) {
|
||||
// if (X86_HAVE(AVX2)) {
|
||||
// if (X86_HAVE(AVX512F)) {
|
||||
// if (X86_HAVE(AVX512VL) && //
|
||||
// X86_HAVE(AVX512BW) && //
|
||||
// X86_HAVE(AVX512DQ) && //
|
||||
// X86_HAVE(AVX512_VNNI) && //
|
||||
// X86_HAVE(AVX512_BF16)) {
|
||||
// // AMD Zen4+ (2023-)
|
||||
// sgemm = llamafile_sgemm_amd_zen4;
|
||||
// mixmul = llamafile_mixmul_amd_zen4;
|
||||
// iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
// } else {
|
||||
// // Intel Xeon Skylake+ (2015-)
|
||||
// sgemm = llamafile_sgemm_amd_avx512f;
|
||||
// mixmul = llamafile_mixmul_amd_avx512f;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else if (X86_HAVE(AVXVNNI)) {
|
||||
// // Intel Alderlake (2021-)
|
||||
// sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
// mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// } else {
|
||||
// // Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// // AMD Excavator (2015-2022)
|
||||
// sgemm = llamafile_sgemm_amd_avx2;
|
||||
// mixmul = llamafile_mixmul_amd_avx2;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD Piledriver (2011-2014)
|
||||
// sgemm = llamafile_sgemm_amd_fma;
|
||||
// mixmul = llamafile_mixmul_amd_fma;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// // AMD Bulldozer (2011)
|
||||
// sgemm = llamafile_sgemm_amd_avx;
|
||||
// mixmul = llamafile_mixmul_amd_avx;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD K8/Barcelona (2003-2010)
|
||||
// // Intel Core/Nehalem (2006-2009)
|
||||
// sgemm = llamafile_sgemm_unsupported;
|
||||
// mixmul = llamafile_mixmul_unsupported;
|
||||
// }
|
||||
|
||||
#if defined(__AVX__)
|
||||
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
|
||||
#if defined(__AVX2__)
|
||||
#if defined(__AVX512F__)
|
||||
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
|
||||
// AMD Zen4+ (2023-)
|
||||
sgemm = llamafile_sgemm_amd_zen4;
|
||||
mixmul = llamafile_mixmul_amd_zen4;
|
||||
iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// use the ARM implementation when building with NPU support
#include "sgemm_arm.cpp"
|
||||
#else
|
||||
// Intel Xeon Skylake+ (2015-)
|
||||
sgemm = llamafile_sgemm_amd_avx512f;
|
||||
mixmul = llamafile_mixmul_amd_avx512f;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#elif defined(__AVXVNNI__)
|
||||
// Intel Alderlake (2021-)
|
||||
sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#else
|
||||
// Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// AMD Excavator (2015-2022)
|
||||
sgemm = llamafile_sgemm_amd_avx2;
|
||||
mixmul = llamafile_mixmul_amd_avx2;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// AMD Piledriver (2011-2014)
|
||||
sgemm = llamafile_sgemm_amd_fma;
|
||||
mixmul = llamafile_mixmul_amd_fma;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// AMD Bulldozer (2011)
|
||||
sgemm = llamafile_sgemm_amd_avx;
|
||||
mixmul = llamafile_mixmul_amd_avx;
|
||||
#endif
|
||||
#else
|
||||
// AMD K8/Barcelona (2003-2010)
|
||||
// Intel Core/Nehalem (2006-2009)
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
|
||||
#elif defined(__aarch64__)
|
||||
long hwcap = getauxval(AT_HWCAP);
|
||||
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
|
||||
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
|
||||
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
|
||||
// e.g. Apple M1, Raspberry Pi 5
|
||||
sgemm = llamafile_sgemm_arm82;
|
||||
mixmul = llamafile_mixmul_arm82;
|
||||
iqk_mixmul = iqk_mul_mat_moe_arm82;
|
||||
} else {
|
||||
// ARM64 baseline ISA
|
||||
sgemm = llamafile_sgemm_arm80;
|
||||
mixmul = llamafile_mixmul_arm80;
|
||||
}
|
||||
#else
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
}
|
||||
} funcs;
|
||||
|
||||
/**
|
||||
* Performs optimized matrix multiplication on CPU.
|
||||
*
|
||||
* This subroutine may compute C = Aᵀ * B with column major ordering.
|
||||
* Despite its name, this isn't a generalized implementation. Work is
|
||||
* only performed when a handwritten kernel is written and available.
|
||||
* Otherwise the caller should fall back to a general matmul routine.
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
* @param n is cols in `B` and `C`
|
||||
* @param k is cols in `A` and rows in `B`
|
||||
* @param A is first input matrix (always transposed)
|
||||
* @param lda is row stride of `A`
|
||||
* @param B is second input matrix (never transposed)
|
||||
* @param ldb is row stride of `B`
|
||||
* @param C is input/output array of output matrices
|
||||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param task is GGML task type
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @param precision may be used to control the internal compute type
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
|
||||
precision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs "mixture of experts" tensor multiplication on CPU.
|
||||
*/
|
||||
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
|
||||
return funcs.mixmul(params, weights, thought, plan, result);
|
||||
}
|
||||
|
||||
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
|
||||
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
|
||||
}
|
||||
// otherwise use the x86 implementation
#include "sgemm_x86.cpp"
|
||||
#endif
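Editorial note: a minimal sketch of how a caller might drive the dispatched entry point defined above; it assumes contiguous operands and the ggml constants named below (GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F32, GGML_PREC_DEFAULT, all from ggml.h), and it is illustrative only, not part of the commit.

#include "sgemm.h"   // declares llamafile_sgemm

// Single-threaded fp32 case with fully contiguous operands, so the leading
// dimensions are simply lda = ldb = k and ldc = m (the asserts in
// llamafile_sgemm require lda >= k, ldb >= k, ldc >= m).
static void sgemm_f32_example(long m, long n, long k,
                              const float* A, const float* B, float* C) {
    bool ok = llamafile_sgemm(m, n, k,
                              A, /*lda=*/k,
                              B, /*ldb=*/k,
                              C, /*ldc=*/m,
                              /*ith=*/0, /*nth=*/1,
                              GGML_TASK_TYPE_COMPUTE,
                              GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
                              GGML_PREC_DEFAULT);
    if (!ok) {
        // no handwritten kernel matched; fall back to a generic matmul path
    }
}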
|
204  third_party/llamafile/sgemm_arm.cpp  vendored  (new file)
@@ -0,0 +1,204 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
|
||||
// Copyright 2024 Mozilla Foundation.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "sgemm.h"
|
||||
// #include <cosmo.h>
|
||||
// #include <cpuid.h>
|
||||
// #include <libc/sysv/consts/hwcap.h>
|
||||
#include <stdio.h>
|
||||
// #include <sys/auxv.h>
|
||||
#include <cassert>
|
||||
// #include "llamafile.h"
|
||||
|
||||
static const struct GemmFuncs {
|
||||
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
|
||||
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
|
||||
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
|
||||
// typeof(llamafile_sgemm)* sgemm;
|
||||
// typeof(llamafile_mixmul)* mixmul;
|
||||
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
|
||||
GemmFuncs() {
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
// if (X86_HAVE(AVX)) {
|
||||
// if (X86_HAVE(FMA)) {
|
||||
// if (X86_HAVE(AVX2)) {
|
||||
// if (X86_HAVE(AVX512F)) {
|
||||
// if (X86_HAVE(AVX512VL) && //
|
||||
// X86_HAVE(AVX512BW) && //
|
||||
// X86_HAVE(AVX512DQ) && //
|
||||
// X86_HAVE(AVX512_VNNI) && //
|
||||
// X86_HAVE(AVX512_BF16)) {
|
||||
// // AMD Zen4+ (2023-)
|
||||
// sgemm = llamafile_sgemm_amd_zen4;
|
||||
// mixmul = llamafile_mixmul_amd_zen4;
|
||||
// iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
// } else {
|
||||
// // Intel Xeon Skylake+ (2015-)
|
||||
// sgemm = llamafile_sgemm_amd_avx512f;
|
||||
// mixmul = llamafile_mixmul_amd_avx512f;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else if (X86_HAVE(AVXVNNI)) {
|
||||
// // Intel Alderlake (2021-)
|
||||
// sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
// mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// } else {
|
||||
// // Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// // AMD Excavator (2015-2022)
|
||||
// sgemm = llamafile_sgemm_amd_avx2;
|
||||
// mixmul = llamafile_mixmul_amd_avx2;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD Piledriver (2011-2014)
|
||||
// sgemm = llamafile_sgemm_amd_fma;
|
||||
// mixmul = llamafile_mixmul_amd_fma;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// // AMD Bulldozer (2011)
|
||||
// sgemm = llamafile_sgemm_amd_avx;
|
||||
// mixmul = llamafile_mixmul_amd_avx;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD K8/Barcelona (2003-2010)
|
||||
// // Intel Core/Nehalem (2006-2009)
|
||||
// sgemm = llamafile_sgemm_unsupported;
|
||||
// mixmul = llamafile_mixmul_unsupported;
|
||||
// }
|
||||
|
||||
#if defined(__AVX__)
|
||||
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
|
||||
#if defined(__AVX2__)
|
||||
#if defined(__AVX512F__)
|
||||
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
|
||||
// AMD Zen4+ (2023-)
|
||||
sgemm = llamafile_sgemm_amd_zen4;
|
||||
mixmul = llamafile_mixmul_amd_zen4;
|
||||
iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
#else
|
||||
// Intel Xeon Skylake+ (2015-)
|
||||
sgemm = llamafile_sgemm_amd_avx512f;
|
||||
mixmul = llamafile_mixmul_amd_avx512f;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#elif defined(__AVXVNNI__)
|
||||
// Intel Alderlake (2021-)
|
||||
sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#else
|
||||
// Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// AMD Excavator (2015-2022)
|
||||
sgemm = llamafile_sgemm_amd_avx2;
|
||||
mixmul = llamafile_mixmul_amd_avx2;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// AMD Piledriver (2011-2014)
|
||||
sgemm = llamafile_sgemm_amd_fma;
|
||||
mixmul = llamafile_mixmul_amd_fma;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// AMD Bulldozer (2011)
|
||||
sgemm = llamafile_sgemm_amd_avx;
|
||||
mixmul = llamafile_mixmul_amd_avx;
|
||||
#endif
|
||||
#else
|
||||
// AMD K8/Barcelona (2003-2010)
|
||||
// Intel Core/Nehalem (2006-2009)
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
|
||||
#elif defined(__aarch64__)
|
||||
// long hwcap = getauxval(AT_HWCAP);
|
||||
// if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
|
||||
// (hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
|
||||
// (hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
|
||||
// // e.g. Apple M1, Raspberry Pi 5
|
||||
// sgemm = llamafile_sgemm_arm82;
|
||||
// mixmul = llamafile_mixmul_arm82;
|
||||
// iqk_mixmul = iqk_mul_mat_moe_arm82;
|
||||
// } else {
|
||||
// ARM64 baseline ISA
|
||||
sgemm = llamafile_sgemm_arm80;
|
||||
mixmul = llamafile_mixmul_arm80;
|
||||
// }
|
||||
#else
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
}
|
||||
} funcs;
|
||||
|
||||
/**
|
||||
* Performs optimized matrix multiplication on CPU.
|
||||
*
|
||||
* This subroutine may compute C = Aᵀ * B with column major ordering.
|
||||
* Despite its name, this isn't a generalized implementation. Work is
|
||||
* only performed when a handwritten kernel is written and available.
|
||||
* Otherwise the caller should fall back to a general matmul routine.
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
* @param n is cols in `B` and `C`
|
||||
* @param k is cols in `A` and rows in `B`
|
||||
* @param A is first input matrix (always transposed)
|
||||
* @param lda is row stride of `A`
|
||||
* @param B is second input matrix (never transposed)
|
||||
* @param ldb is row stride of `B`
|
||||
* @param C is input/output array of output matrices
|
||||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param task is GGML task type
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @param precision may be used to control the internal compute type
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
|
||||
precision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs "mixture of experts" tensor multiplication on CPU.
|
||||
*/
|
||||
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
|
||||
return funcs.mixmul(params, weights, thought, plan, result);
|
||||
}
|
||||
|
||||
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
|
||||
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
|
||||
}
|
204  third_party/llamafile/sgemm_x86.cpp  vendored  (new file)
@@ -0,0 +1,204 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
|
||||
// Copyright 2024 Mozilla Foundation.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "sgemm.h"
|
||||
// #include <cosmo.h>
|
||||
// #include <cpuid.h>
|
||||
// #include <libc/sysv/consts/hwcap.h>
|
||||
#include <stdio.h>
|
||||
// #include <sys/auxv.h>
|
||||
#include <cassert>
|
||||
// #include "llamafile.h"
|
||||
|
||||
static const struct GemmFuncs {
|
||||
bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
|
||||
bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
|
||||
bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
|
||||
// typeof(llamafile_sgemm)* sgemm;
|
||||
// typeof(llamafile_mixmul)* mixmul;
|
||||
// typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
|
||||
GemmFuncs() {
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
// if (X86_HAVE(AVX)) {
|
||||
// if (X86_HAVE(FMA)) {
|
||||
// if (X86_HAVE(AVX2)) {
|
||||
// if (X86_HAVE(AVX512F)) {
|
||||
// if (X86_HAVE(AVX512VL) && //
|
||||
// X86_HAVE(AVX512BW) && //
|
||||
// X86_HAVE(AVX512DQ) && //
|
||||
// X86_HAVE(AVX512_VNNI) && //
|
||||
// X86_HAVE(AVX512_BF16)) {
|
||||
// // AMD Zen4+ (2023-)
|
||||
// sgemm = llamafile_sgemm_amd_zen4;
|
||||
// mixmul = llamafile_mixmul_amd_zen4;
|
||||
// iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
// } else {
|
||||
// // Intel Xeon Skylake+ (2015-)
|
||||
// sgemm = llamafile_sgemm_amd_avx512f;
|
||||
// mixmul = llamafile_mixmul_amd_avx512f;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else if (X86_HAVE(AVXVNNI)) {
|
||||
// // Intel Alderlake (2021-)
|
||||
// sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
// mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// } else {
|
||||
// // Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// // AMD Excavator (2015-2022)
|
||||
// sgemm = llamafile_sgemm_amd_avx2;
|
||||
// mixmul = llamafile_mixmul_amd_avx2;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD Piledriver (2011-2014)
|
||||
// sgemm = llamafile_sgemm_amd_fma;
|
||||
// mixmul = llamafile_mixmul_amd_fma;
|
||||
// if (X86_HAVE(F16C))
|
||||
// iqk_mixmul = iqk_mul_mat_moe;
|
||||
// }
|
||||
// } else {
|
||||
// // Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// // AMD Bulldozer (2011)
|
||||
// sgemm = llamafile_sgemm_amd_avx;
|
||||
// mixmul = llamafile_mixmul_amd_avx;
|
||||
// }
|
||||
// } else {
|
||||
// // AMD K8/Barcelona (2003-2010)
|
||||
// // Intel Core/Nehalem (2006-2009)
|
||||
// sgemm = llamafile_sgemm_unsupported;
|
||||
// mixmul = llamafile_mixmul_unsupported;
|
||||
// }
|
||||
|
||||
#if defined(__AVX__)
|
||||
#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
|
||||
#if defined(__AVX2__)
|
||||
#if defined(__AVX512F__)
|
||||
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
|
||||
// AMD Zen4+ (2023-)
|
||||
sgemm = llamafile_sgemm_amd_zen4;
|
||||
mixmul = llamafile_mixmul_amd_zen4;
|
||||
iqk_mixmul = iqk_mul_mat_moe_zen4;
|
||||
#else
|
||||
// Intel Xeon Skylake+ (2015-)
|
||||
sgemm = llamafile_sgemm_amd_avx512f;
|
||||
mixmul = llamafile_mixmul_amd_avx512f;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#elif defined(__AVXVNNI__)
|
||||
// Intel Alderlake (2021-)
|
||||
sgemm = llamafile_sgemm_amd_avxvnni;
|
||||
mixmul = llamafile_mixmul_amd_avxvnni;
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#else
|
||||
// Intel Haswell/Broadwell/Skylake (2013-2020)
|
||||
// AMD Excavator (2015-2022)
|
||||
sgemm = llamafile_sgemm_amd_avx2;
|
||||
mixmul = llamafile_mixmul_amd_avx2;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// AMD Piledriver (2011-2014)
|
||||
sgemm = llamafile_sgemm_amd_fma;
|
||||
mixmul = llamafile_mixmul_amd_fma;
|
||||
#if defined(__F16C__)
|
||||
iqk_mixmul = iqk_mul_mat_moe;
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// Intel Sandybridge/Ivybridge (2010-2012)
|
||||
// AMD Bulldozer (2011)
|
||||
sgemm = llamafile_sgemm_amd_avx;
|
||||
mixmul = llamafile_mixmul_amd_avx;
|
||||
#endif
|
||||
#else
|
||||
// AMD K8/Barcelona (2003-2010)
|
||||
// Intel Core/Nehalem (2006-2009)
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
|
||||
#elif defined(__aarch64__)
|
||||
long hwcap = getauxval(AT_HWCAP);
|
||||
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
|
||||
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
|
||||
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
|
||||
// e.g. Apple M1, Raspberry Pi 5
|
||||
sgemm = llamafile_sgemm_arm82;
|
||||
mixmul = llamafile_mixmul_arm82;
|
||||
iqk_mixmul = iqk_mul_mat_moe_arm82;
|
||||
} else {
|
||||
// ARM64 baseline ISA
|
||||
sgemm = llamafile_sgemm_arm80;
|
||||
mixmul = llamafile_mixmul_arm80;
|
||||
}
|
||||
#else
|
||||
sgemm = llamafile_sgemm_unsupported;
|
||||
mixmul = llamafile_mixmul_unsupported;
|
||||
#endif
|
||||
}
|
||||
} funcs;
|
||||
|
||||
/**
|
||||
* Performs optimized matrix multiplication on CPU.
|
||||
*
|
||||
* This subroutine may compute C = Aᵀ * B with column major ordering.
|
||||
* Despite its name, this isn't a generalized implementation. Work is
|
||||
* only performed when a handwritten kernel is written and available.
|
||||
* Otherwise the caller should fall back to a general matmul routine.
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
* @param n is cols in `B` and `C`
|
||||
* @param k is cols in `A` and rows in `B`
|
||||
* @param A is first input matrix (always transposed)
|
||||
* @param lda is row stride of `A`
|
||||
* @param B is second input matrix (never transposed)
|
||||
* @param ldb is row stride of `B`
|
||||
* @param C is input/output array of output matrices
|
||||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param task is GGML task type
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @param precision may be used to control the internal compute type
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
|
||||
precision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs "mixture of experts" tensor multiplication on CPU.
|
||||
*/
|
||||
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
|
||||
return funcs.mixmul(params, weights, thought, plan, result);
|
||||
}
|
||||
|
||||
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
|
||||
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
|
||||
}
|
@@ -5,6 +5,7 @@

#ifdef __aarch64__
#define llamafile_mixmul llamafile_mixmul_arm80
#define iqk_mul_mat iqk_mul_mat_arm80
#include "tinyblas_cpu_mixmul.inc"

/**
366  third_party/llamafile/tinyblas_cpu_sgemm.inc  vendored
@@ -1,361 +1,7 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
|
||||
// Copyright 2024 Mozilla Foundation.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tinyblas_cpu.h"
|
||||
|
||||
//
|
||||
//
|
||||
// ██████╗ ██╗ █████╗ ██████╗
|
||||
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
|
||||
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
|
||||
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
|
||||
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
|
||||
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
|
||||
//
|
||||
// BASIC LINEAR ALGEBRA SUBPROGRAMS
|
||||
//
|
||||
//
|
||||
// This file implements multithreaded CPU matrix multiplication for the
|
||||
// common contiguous use case C = Aᵀ * B. These kernels are designed to
|
||||
// have excellent performance[1] for matrices that fit in the CPU cache
|
||||
// without imposing any overhead such as cache filling or malloc calls.
|
||||
//
|
||||
// This implementation does not guarantee any upper bound with rounding
|
||||
// errors, which grow along with k. Our goal's to maximally exploit the
|
||||
// hardware for performance, and then use whatever resources remain for
|
||||
// improving numerical accuracy.
|
||||
//
|
||||
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
|
||||
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename TC>
|
||||
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
switch (Atype) {
|
||||
case GGML_TYPE_F32: {
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__AVX__) || defined(__AVX2__)
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON)
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU
// use the ARM implementation when building with NPU support
#include "tinyblas_cpu_sgemm_arm.inc"
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_BF16: {
|
||||
#if defined(__AVX512BF16__)
|
||||
if (k % 32)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_BF16)
|
||||
return NOT_SUPPORTED;
|
||||
if (!FLAG_precise) {
|
||||
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
} else {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__AVX2__)
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_F16: {
|
||||
#if defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
|
||||
// if (X86_CHECK(F16C)) {
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
// } else {
|
||||
// return NOT_SUPPORTED;
|
||||
// }
|
||||
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
||||
if (n < 2 && !FLAG_precise)
|
||||
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
|
||||
return NOT_SUPPORTED;
|
||||
if (precision == GGML_PREC_F32) {
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
} else {
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
||||
if (n < 2 && !FLAG_precise)
|
||||
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
|
||||
return NOT_SUPPORTED;
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_Q8_0: {
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
|
||||
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
|
||||
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_Q4_0: {
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
|
||||
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
|
||||
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
default:
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
(void)m;
|
||||
(void)n;
|
||||
(void)k;
|
||||
(void)A;
|
||||
(void)lda;
|
||||
(void)B;
|
||||
(void)ldb;
|
||||
(void)C;
|
||||
(void)ldc;
|
||||
(void)ith;
|
||||
(void)nth;
|
||||
(void)Atype;
|
||||
(void)Btype;
|
||||
(void)precision;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* Performs optimized matrix multiplication on CPU.
|
||||
*
|
||||
* This subroutine may compute C = Aᵀ * B with column major ordering.
|
||||
* Despite its name, this isn't a generalized implementation. Work is
|
||||
* only performed when a handwritten kernel is written and available.
|
||||
* Otherwise the caller should fall back to a general matmul routine.
|
||||
*
|
||||
* For example, for single-threaded single-precision GEMM you can say
|
||||
*
|
||||
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
*                 GGML_TASK_TYPE_COMPUTE,
*                 GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
*                 GGML_PREC_DEFAULT);
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
* @param n is cols in `B` and `C`
|
||||
* @param k is cols in `A` and rows in `B`
|
||||
* @param A is first input matrix (always transposed)
|
||||
* @param lda is row stride of `A`
|
||||
* @param B is second input matrix (never transposed)
|
||||
* @param ldb is row stride of `B`
|
||||
* @param C is input/output array of output matrices
|
||||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @param precision may be used to control the internal compute type
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
assert(m >= 0);
|
||||
assert(n >= 0);
|
||||
assert(k >= 0);
|
||||
assert(lda >= k);
|
||||
assert(ldb >= k);
|
||||
assert(ldc >= m);
|
||||
assert(nth > 0);
|
||||
assert(ith < nth);
|
||||
|
||||
#if QK_K == 256
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
|
||||
/* moonll: accept more Btype combinations */
|
||||
|
||||
if (Ctype == GGML_TYPE_F32){
|
||||
if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
|
||||
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
|
||||
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
|
||||
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
|
||||
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
|
||||
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
switch (Ctype) {
|
||||
case GGML_TYPE_F32:
|
||||
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
|
||||
Btype, Ctype, precision);
|
||||
default:
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
}
|
||||
// otherwise use the x86 implementation
#include "tinyblas_cpu_sgemm_x86.inc"
|
||||
#endif
|
471  third_party/llamafile/tinyblas_cpu_sgemm_arm.inc  vendored  (new file)
@@ -0,0 +1,471 @@
|
|||
// Adapted from
|
||||
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
|
||||
// Copyright 2024 Mozilla Foundation.
|
||||
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
|
||||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tinyblas_cpu.h"
|
||||
#include <arm_neon.h>
|
||||
#include <ostream>
|
||||
#include <iostream>
|
||||
//
|
||||
//
|
||||
// ██████╗ ██╗ █████╗ ██████╗
|
||||
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
|
||||
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
|
||||
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
|
||||
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
|
||||
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
|
||||
//
|
||||
// BASIC LINEAR ALGEBRA SUBPROGRAMS
|
||||
//
|
||||
//
|
||||
// This file implements multithreaded CPU matrix multiplication for the
|
||||
// common contiguous use case C = Aᵀ * B. These kernels are designed to
|
||||
// have excellent performance[1] for matrices that fit in the CPU cache
|
||||
// without imposing any overhead such as cache filling or malloc calls.
|
||||
//
|
||||
// This implementation does not guarantee any upper bound with rounding
|
||||
// errors, which grow along with k. Our goal's to maximally exploit the
|
||||
// hardware for performance, and then use whatever resources remain for
|
||||
// improving numerical accuracy.
|
||||
//
|
||||
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
|
||||
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename TC>
|
||||
void SgemmHelperN1Neon2(long m, long n, long k, const float16_t* A, long lda, const float16_t* B, long ldb,
|
||||
TC* C, long ldc, int ith, int nth) {
|
||||
// A m * k B n * k c n * m
|
||||
const long NVL = 8;
|
||||
long kk = k / (NVL * 4);
|
||||
kk = kk * (NVL * 4);
|
||||
long length = (m / nth) + (ith < (m % nth) ? 1 : 0);
|
||||
long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth));
|
||||
long endRow = startRow + length;
|
||||
for (long i = startRow; i < endRow; i ++) {
|
||||
const float16_t* tA = A + i * lda;
|
||||
float32x4_t c0 = vdupq_n_f32(0);
|
||||
float32x4_t c1 = vdupq_n_f32(0);
|
||||
float32x4_t c2 = vdupq_n_f32(0);
|
||||
float32x4_t c3 = vdupq_n_f32(0);
|
||||
float32x4_t c4 = vdupq_n_f32(0);
|
||||
float32x4_t c5 = vdupq_n_f32(0);
|
||||
float32x4_t c6 = vdupq_n_f32(0);
|
||||
float32x4_t c7 = vdupq_n_f32(0);
|
||||
for (long j = 0; j < kk; j += NVL * 4) {
|
||||
__builtin_prefetch(tA + 192, 0, 0);
|
||||
float16x8_t a0 = vld1q_f16(tA + j);
|
||||
float16x8_t b0 = vld1q_f16(B + j);
|
||||
c0 = vfmlalq_low_f16(c0, a0, b0);
|
||||
c1 = vfmlalq_high_f16(c1, a0, b0);
|
||||
float16x8_t a1 = vld1q_f16(tA + j + NVL);
|
||||
float16x8_t b1 = vld1q_f16(B + j + NVL);
|
||||
c2 = vfmlalq_low_f16(c2, a1, b1);
|
||||
c3 = vfmlalq_high_f16(c3, a1, b1);
|
||||
float16x8_t a2 = vld1q_f16(tA + j + NVL * 2);
|
||||
float16x8_t b2 = vld1q_f16(B + j + NVL * 2);
|
||||
c4 = vfmlalq_low_f16(c4, a2, b2);
|
||||
c5 = vfmlalq_high_f16(c5, a2, b2);
|
||||
float16x8_t a3 = vld1q_f16(tA + j + NVL * 3);
|
||||
float16x8_t b3 = vld1q_f16(B + j + NVL * 3);
|
||||
c6 = vfmlalq_low_f16(c6, a3, b3);
|
||||
c7 = vfmlalq_high_f16(c7, a3, b3);
|
||||
}
|
||||
if (k - kk >= NVL * 2) {
|
||||
float16x8_t a0 = vld1q_f16(tA + kk);
|
||||
float16x8_t b0 = vld1q_f16(B + kk);
|
||||
c0 = vfmlalq_low_f16(c0, a0, b0);
|
||||
c1 = vfmlalq_high_f16(c1, a0, b0);
|
||||
float16x8_t a1 = vld1q_f16(tA + kk + NVL);
|
||||
float16x8_t b1 = vld1q_f16(B + kk + NVL);
|
||||
c2 = vfmlalq_low_f16(c2, a1, b1);
|
||||
c3 = vfmlalq_high_f16(c3, a1, b1);
|
||||
kk += NVL * 2;
|
||||
}
|
||||
if (k - kk >= NVL) {
|
||||
float16x8_t a = vld1q_f16(tA + kk);
|
||||
float16x8_t b = vld1q_f16(B + kk);
|
||||
c0 = vfmlalq_low_f16(c0, a, b);
|
||||
c1 = vfmlalq_high_f16(c1, a, b);
|
||||
kk += NVL;
|
||||
}
|
||||
TC sum = 0.0f;
|
||||
for (long j = kk; j < k; j ++) {
|
||||
sum += (float32_t)tA[j] * (float32_t)B[j];
|
||||
}
|
||||
c0 = vaddq_f32(c0, c1);
|
||||
c2 = vaddq_f32(c2, c3);
|
||||
c4 = vaddq_f32(c4, c5);
|
||||
c6 = vaddq_f32(c6, c7);
|
||||
c0 = vaddq_f32(c0, c2);
|
||||
c4 = vaddq_f32(c4, c6);
|
||||
sum += vaddvq_f32(c0) + vaddvq_f32(c4);
|
||||
C[i] = sum;
|
||||
}
|
||||
return;
|
||||
}
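// Editorial comment on the kernel above (not part of the commit): each
// vfmlalq_low_f16 / vfmlalq_high_f16 pair widens eight fp16 products into
// two float32x4_t accumulators, and eight independent accumulators (c0..c7)
// are kept live so consecutive fused multiply-adds do not stall on one
// another; they are folded together with vaddq_f32/vaddvq_f32 only after
// the main loop, and any ragged tail of fewer than 8 elements is handled
// in plain scalar fp32 arithmetic.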
|
||||
|
||||
template <typename TC>
|
||||
void SgemmHelperN1(long m, long n, long k, const ggml_fp16_t* A_, long lda, const ggml_fp16_t* B_, long ldb,
|
||||
TC* C, long ldc, int ith, int nth) {
|
||||
// A m * k B n * k c n * m
|
||||
float16_t *A = (float16_t*)A_;
|
||||
float16_t *B = (float16_t*)B_;
|
||||
long rowsPerThread = m / nth;
|
||||
long startRow = ith * rowsPerThread;
|
||||
long endRow = (ith == nth - 1) ? m : startRow + rowsPerThread;
|
||||
for (long i = startRow; i < endRow; i ++) {
|
||||
TC sum = 0.0f;
|
||||
for (long j = 0; j < k; j ++) {
|
||||
sum += (float32_t)A[i * lda + j] * (float32_t)B[j];
|
||||
}
|
||||
C[i] = sum;
|
||||
}
|
||||
return;
|
||||
}
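Editorial note: the two helpers above split the m output rows across nth threads in slightly different ways; SgemmHelperN1Neon2 gives the first (m % nth) threads one extra row so the split stays balanced. A standalone sketch of that arithmetic, with a worked example (the helper name is hypothetical, not part of the commit):

// Balanced row split: thread ith of nth gets rows [startRow, endRow).
static void split_rows_example(long m, int ith, int nth,
                               long* startRow, long* endRow) {
    long base = m / nth;
    long rem  = m % nth;
    long len  = base + (ith < rem ? 1 : 0);
    *startRow = ith * base + (ith < rem ? ith : rem);
    *endRow   = *startRow + len;
    // e.g. m = 10, nth = 4 gives rows [0,3), [3,6), [6,8), [8,10).
}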
|
||||
|
||||
template <typename TC>
|
||||
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
switch (Atype) {
|
||||
case GGML_TYPE_F32: {
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__AVX__) || defined(__AVX2__)
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON)
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
|
||||
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_BF16: {
|
||||
#if defined(__AVX512BF16__)
|
||||
if (k % 32)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_BF16)
|
||||
return NOT_SUPPORTED;
|
||||
if (!FLAG_precise) {
|
||||
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
} else {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__AVX2__)
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
|
||||
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_F16: {
|
||||
#if defined(__AVX512F__)
|
||||
if (k % 16)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
|
||||
// if (X86_CHECK(F16C)) {
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32 && n < 2) {
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
// } else {
|
||||
// return NOT_SUPPORTED;
|
||||
// }
|
||||
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
||||
if (n < 2 && !FLAG_precise) {
|
||||
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
|
||||
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
|
||||
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
|
||||
return true;
|
||||
}
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
if (precision == GGML_PREC_F32) {
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
} else {
|
||||
if (k % 8)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_F16)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
||||
if (n < 2 && !FLAG_precise) {
|
||||
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
|
||||
if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {
|
||||
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);
|
||||
return true;
|
||||
}
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
if (k % 4)
|
||||
return NOT_SUPPORTED;
|
||||
if (Btype != GGML_TYPE_F32)
|
||||
return NOT_SUPPORTED;
|
||||
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
|
||||
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_Q8_0: {
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
|
||||
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
|
||||
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
case GGML_TYPE_Q4_0: {
|
||||
if (Btype == GGML_TYPE_F32)
|
||||
return WANT_QUANTIZATION;
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return NOT_SUPPORTED;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
|
||||
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
|
||||
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
return true;
|
||||
#else
|
||||
return NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
default:
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
(void)m;
|
||||
(void)n;
|
||||
(void)k;
|
||||
(void)A;
|
||||
(void)lda;
|
||||
(void)B;
|
||||
(void)ldb;
|
||||
(void)C;
|
||||
(void)ldc;
|
||||
(void)ith;
|
||||
(void)nth;
|
||||
(void)Atype;
|
||||
(void)Btype;
|
||||
(void)precision;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* Performs optimized matrix multiplication on CPU.
|
||||
*
|
||||
* This subroutine may compute C = Aᵀ * B with column major ordering.
|
||||
* Despite its name, this isn't a generalized implementation. Work is
|
||||
* only performed when a handwritten kernel is written and available.
|
||||
* Otherwise the caller should fall back to a general matmul routine.
|
||||
*
|
||||
* For example, for single-threaded single-precision GEMM you can say
|
||||
*
|
||||
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
*                 GGML_TASK_TYPE_COMPUTE,
*                 GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
*                 GGML_PREC_DEFAULT);
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
* @param n is cols in `B` and `C`
|
||||
* @param k is cols in `A` and rows in `B`
|
||||
* @param A is first input matrix (always transposed)
|
||||
* @param lda is row stride of `A`
|
||||
* @param B is second input matrix (never transposed)
|
||||
* @param ldb is row stride of `B`
|
||||
* @param C is input/output array of output matrices
|
||||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @param precision may be used to control the internal compute type
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
|
||||
assert(m >= 0);
|
||||
assert(n >= 0);
|
||||
assert(k >= 0);
|
||||
assert(lda >= k);
|
||||
assert(ldb >= k);
|
||||
assert(ldc >= m);
|
||||
assert(nth > 0);
|
||||
assert(ith < nth);
|
||||
|
||||
#if QK_K == 256
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
|
||||
/* moonll: accept more Btype combinations */
|
||||
// if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
|
||||
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32){
|
||||
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
|
||||
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
|
||||
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
|
||||
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
|
||||
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
|
||||
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
|
||||
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
|
||||
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
switch (Ctype) {
|
||||
case GGML_TYPE_F32:
|
||||
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
|
||||
Btype, Ctype, precision);
|
||||
default:
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
}
|
361  third_party/llamafile/tinyblas_cpu_sgemm_x86.inc  vendored  (new file)
@@ -0,0 +1,361 @@
|
|||
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tinyblas_cpu.h"

//
//
//                  ██████╗ ██╗ █████╗ ██████╗
//                  ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
//                  ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
//                    ██║  ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
//                    ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
//                    ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
//                  BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, 'LLaMA Now Goes Faster on CPUs', Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].

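Concretely, the C = Aᵀ * B convention used throughout this file means A holds m rows of k elements (row stride lda ≥ k), B holds n rows of k elements (row stride ldb ≥ k), and C is written column-major with column stride ldc ≥ m, which is exactly what the asserts in llamafile_sgemm below require. A scalar reference loop with the same semantics, shown only to pin down the indexing (the tuned kernels below tile and vectorize this):

// Reference semantics only: same indexing as the tinyBLAS kernels, no tiling.
static void sgemm_reference(long m, long n, long k,
                            const float* A, long lda,
                            const float* B, long ldb,
                            float* C, long ldc) {
    for (long j = 0; j < n; ++j)          // columns of C, rows of B
        for (long i = 0; i < m; ++i) {    // rows of C, rows of A
            float sum = 0.0f;
            for (long l = 0; l < k; ++l)
                sum += A[lda * i + l] * B[ldb * j + l];
            C[ldc * j + i] = sum;         // C is column-major: C(i, j)
        }
}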
namespace {

template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
    switch (Atype) {
        case GGML_TYPE_F32: {
            if (Btype != GGML_TYPE_F32)
                return NOT_SUPPORTED;
#if defined(__AVX512F__)
            if (k % 16)
                return NOT_SUPPORTED;
            tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__AVX__) || defined(__AVX2__)
            if (k % 8)
                return NOT_SUPPORTED;
            tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__ARM_NEON)
            if (k % 4)
                return NOT_SUPPORTED;
            tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#else
            return NOT_SUPPORTED;
#endif
        }

        case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
            if (k % 32)
                return NOT_SUPPORTED;
            if (Btype == GGML_TYPE_F32 && n < 2) {
                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
                    k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            }
            if (Btype == GGML_TYPE_F32)
                return WANT_QUANTIZATION;
            if (Btype != GGML_TYPE_BF16)
                return NOT_SUPPORTED;
            if (!FLAG_precise) {
                tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            } else {
                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            }
#elif defined(__AVX512F__)
            if (k % 16)
                return NOT_SUPPORTED;
            tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__AVX2__)
            if (k % 8)
                return NOT_SUPPORTED;
            if (Btype != GGML_TYPE_F32)
                return NOT_SUPPORTED;
            tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
            if (k % 4)
                return NOT_SUPPORTED;
            if (Btype != GGML_TYPE_F32)
                return NOT_SUPPORTED;
            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#else
            return NOT_SUPPORTED;
#endif
        }

        case GGML_TYPE_F16: {
#if defined(__AVX512F__)
            if (k % 16)
                return NOT_SUPPORTED;
            if (Btype == GGML_TYPE_F32 && n < 2) {
                tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            }
            if (Btype == GGML_TYPE_F32)
                return WANT_QUANTIZATION;
            if (Btype != GGML_TYPE_F16)
                return NOT_SUPPORTED;
            tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
            // if (X86_CHECK(F16C)) {
            if (k % 8)
                return NOT_SUPPORTED;
            if (Btype == GGML_TYPE_F32 && n < 2) {
                tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            }
            if (Btype == GGML_TYPE_F32)
                return WANT_QUANTIZATION;
            if (Btype != GGML_TYPE_F16)
                return NOT_SUPPORTED;
            tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
            // } else {
            //     return NOT_SUPPORTED;
            // }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
            if (n < 2 && !FLAG_precise)
                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
                return NOT_SUPPORTED;
            if (precision == GGML_PREC_F32) {
                if (k % 4)
                    return NOT_SUPPORTED;
                if (Btype != GGML_TYPE_F32)
                    return NOT_SUPPORTED;
                tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            } else {
                if (k % 8)
                    return NOT_SUPPORTED;
                if (Btype == GGML_TYPE_F32)
                    return WANT_QUANTIZATION;
                if (Btype != GGML_TYPE_F16)
                    return NOT_SUPPORTED;
                tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
                    k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
                tb.matmul(m, n, task);
                return true;
            }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
            if (n < 2 && !FLAG_precise)
                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
                return NOT_SUPPORTED;
            if (k % 4)
                return NOT_SUPPORTED;
            if (Btype != GGML_TYPE_F32)
                return NOT_SUPPORTED;
            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
                k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#else
            return NOT_SUPPORTED;
#endif
        }

        case GGML_TYPE_Q8_0: {
            if (Btype == GGML_TYPE_F32)
                return WANT_QUANTIZATION;
            if (Btype != GGML_TYPE_Q8_0)
                return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
            tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__ARM_FEATURE_DOTPROD)
            tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#else
            return NOT_SUPPORTED;
#endif
        }

        case GGML_TYPE_Q4_0: {
            if (Btype == GGML_TYPE_F32)
                return WANT_QUANTIZATION;
            if (Btype != GGML_TYPE_Q8_0)
                return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
            tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#elif defined(__ARM_FEATURE_DOTPROD)
            tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
            tb.matmul(m, n, task);
            return true;
#else
            return NOT_SUPPORTED;
#endif
        }

        default:
            return NOT_SUPPORTED;
    }

    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)Atype;
    (void)Btype;
    (void)precision;
}

} // namespace

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is written and available.
 * Otherwise the caller should fall back to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
 *                     GGML_PREC_DEFAULT);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @param precision may be used to control the internal compute type
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);

#if QK_K == 256
#if defined(__x86_64__) || defined(_M_X64)
#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
    /* moonll: accept more Btype */

    if (Ctype == GGML_TYPE_F32) {
        if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A, lda, Btype, B, ldb, (float*)C, ldc, ith, nth)) {
            return true;
        }
    }

#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
            return true;
        }
    }
    if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
        // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
        assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
            return true;
        }
    }
#endif
#endif

    switch (Ctype) {
        case GGML_TYPE_F32:
            return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
                                        Btype, Ctype, precision);
        default:
            return NOT_SUPPORTED;
    }
}
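One point worth spelling out about the ith/nth parameters above: the kernels are cooperatively multithreaded, so every worker calls llamafile_sgemm with identical matrix arguments plus its own thread index, and tinyBLAS partitions the output tiles among the nth callers internally. A minimal sketch of that calling pattern with std::thread, shown only to illustrate the convention (in ktransformers the existing backend thread pool plays this role):

#include <thread>
#include <vector>

// Illustration: run one F32 GEMM across `nth` workers. Each worker passes the
// same operands and its own `ith`; the kernel divides the work by itself.
static void sgemm_threaded(long m, long n, long k,
                           const float* A, long lda,
                           const float* B, long ldb,
                           float* C, long ldc, int nth, int task) {
    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith)
        workers.emplace_back([=] {
            llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
                            ith, nth, task,
                            GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
                            GGML_PREC_DEFAULT);
        });
    for (auto& t : workers)
        t.join();  // C is fully written once every worker has returned
}

A real caller would also check the boolean result, which is identical across workers for a given build and type combination, and fall back to a generic matmul when the request is declined.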