Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)
commit 92afdfcae4

    Merge branch 'upstream' into concedo_experimental

    # Conflicts:
    #	.github/labeler.yml
    #	.github/workflows/server.yml
    #	.gitignore
    #	CMakeLists.txt
    #	Makefile
    #	README-sycl.md
    #	README.md
    #	llama.cpp
    #	requirements/requirements-convert-hf-to-gguf-update.txt
    #	requirements/requirements-convert-hf-to-gguf.txt
    #	requirements/requirements-convert-legacy-llama.txt
    #	scripts/sync-ggml.last
    #	tests/test-tokenizer-random.py

44 changed files with 10304 additions and 8631 deletions
CMakePresets.json

@@ -11,9 +11,21 @@
       "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
     }
   },
+  {
+    "name": "sycl-base",
+    "hidden": true,
+    "generator": "Ninja",
+    "binaryDir": "${sourceDir}/build-${presetName}",
+    "cacheVariables": {
+      "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
+      "CMAKE_CXX_COMPILER": "icx",
+      "LLAMA_SYCL": "ON",
+      "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
+    }
+  },
   { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
-  { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
+  { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
+  { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
   { "name": "static", "hidden": true, "cacheVariables": { "LLAMA_STATIC": "ON" } },
 
   {
@@ -35,15 +47,18 @@
   },
 
   { "name": "arm64-windows-llvm-debug" , "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
-  { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "release" ] },
-  { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "release", "static" ] },
+  { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
+  { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
 
   { "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
-  { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "release" ] },
-  { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] },
+  { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
+  { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
 
   { "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] },
-  { "name": "x64-windows-msvc-release", "inherits": [ "base", "release" ] },
-  { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "release", "static" ] }
+  { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
+  { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
+
+  { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
+  { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
   ]
 }
Makefile (4 changed lines)

@@ -5,7 +5,7 @@ default: koboldcpp_default koboldcpp_failsafe koboldcpp_openblas koboldcpp_noavx
 tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip whispermain sdmain gguf-split
 dev: koboldcpp_openblas
 dev2: koboldcpp_clblast
+dev3: koboldcpp_vulkan
 
 ifndef UNAME_S
 UNAME_S := $(shell uname -s)
@@ -158,7 +158,7 @@ OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instanc
 
 ifdef LLAMA_CUBLAS
 	CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/local/cuda/targets/sbsa-linux/lib -L/usr/lib/wsl/lib
+	CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/local/cuda/targets/sbsa-linux/lib -L/usr/lib/wsl/lib
 	CUBLAS_OBJS = ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
 	CUBLAS_OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
 	CUBLAS_OBJS += $(OBJS_CUDA_TEMP_INST)
common/common.cpp

@@ -7,7 +7,6 @@
 #include "llama.h"
 
 #include <algorithm>
-#include <cassert>
 #include <cinttypes>
 #include <cmath>
 #include <codecvt>
@@ -543,6 +542,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
         else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
         else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
         else { invalid_param = true; }
         return true;
     }
@@ -1871,6 +1871,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
 
     options.push_back({ "backend" });
     options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
 
     if (llama_supports_mlock()) {
         options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
     }
@@ -2658,7 +2659,14 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     }
 
     // Set the output file
-    std::unique_ptr<FILE, decltype(&fclose)> outfile(fopen(path_temporary.c_str(), "wb"), fclose);
+
+    struct FILE_deleter {
+        void operator()(FILE * f) const {
+            fclose(f);
+        }
+    };
+
+    std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
     if (!outfile) {
         fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str());
         return false;
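The `llama_download_file` change above swaps a function-pointer deleter for an empty functor type. A minimal standalone sketch of the difference (the file path is just an illustration): with `decltype(&fclose)` the `unique_ptr` has to carry the function pointer at runtime, while an empty deleter type contributes no storage, so the smart pointer typically stays the size of a raw pointer.

```cpp
#include <cstdio>
#include <memory>

// Stateless deleter: carries no data, so the unique_ptr stays pointer-sized.
struct FILE_deleter {
    void operator()(FILE * f) const { fclose(f); }
};

int main() {
    using with_fn_ptr  = std::unique_ptr<FILE, decltype(&fclose)>; // stores &fclose at runtime
    using with_functor = std::unique_ptr<FILE, FILE_deleter>;      // deleter is a type, not a value

    // On common ABIs this prints 16 vs 8 (two pointers vs one).
    std::printf("sizeof with &fclose: %zu, with FILE_deleter: %zu\n",
                sizeof(with_fn_ptr), sizeof(with_functor));

    with_functor f(std::fopen("/tmp/example.txt", "wb")); // hypothetical path
    if (f) {
        std::fputs("hello\n", f.get());
    } // fclose runs automatically when f goes out of scope
    return 0;
}
```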
common/common.h

@@ -69,7 +69,6 @@ struct gpt_params {
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-    int32_t n_beams = 0; // if non-zero then use beam search of given width.
     int32_t grp_attn_n = 1; // group-attention factor
     int32_t grp_attn_w = 512; // group-attention width
     int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
convert-hf-to-gguf-update.py

@@ -214,7 +214,7 @@ src_func = f"""
 """
 
 convert_py_pth = pathlib.Path("convert-hf-to-gguf.py")
-convert_py = convert_py_pth.read_text()
+convert_py = convert_py_pth.read_text(encoding="utf-8")
 convert_py = re.sub(
     r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)",
     lambda m: m.group(1) + src_func + m.group(3),
@@ -222,7 +222,7 @@ convert_py = re.sub(
     flags=re.DOTALL | re.MULTILINE,
 )
 
-convert_py_pth.write_text(convert_py)
+convert_py_pth.write_text(convert_py, encoding="utf-8")
 
 logger.info("+++ convert-hf-to-gguf.py was updated")
 
examples/embedding/embedding.cpp

@@ -18,9 +18,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -41,13 +42,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
@@ -98,6 +93,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
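The `batch_add_seq` change above is easy to miss: previously only the final token of each sequence was marked as an output (`i == tokens.size() - 1`); now every token is. A plausible reading, consistent with the new pooling-type check in `main`, is that pooled (mean/CLS) sequence embeddings need the hidden state of every position, not just the last one. Commented restatement of the new helper, with that intent spelled out:

```cpp
// Same shape as the helper in the diff above. llama_seq_id is the library's
// own sequence-id typedef; using it instead of a plain int documents intent
// and stays correct if the underlying type ever changes.
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    size_t n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        // logits=true for every position: mean/CLS pooling reads all of them
        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
    }
}
```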
examples/gritlm/gritlm.cpp

@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
         // run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
+
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {
 
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
-    // create new context - set to embedding mode
-    cparams.embeddings = true;
+    // create generation context
     llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
examples/infill/infill.cpp

@@ -224,7 +224,11 @@ int main(int argc, char ** argv) {
         inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
         embd_inp = inp_pfx;
         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-        embd_inp.push_back(llama_token_middle(model));
+
+        const llama_token middle_token = llama_token_middle(model);
+        if (middle_token >= 0) {
+            embd_inp.push_back(middle_token);
+        }
 
         LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
         LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@@ -529,7 +533,12 @@ int main(int argc, char ** argv) {
             inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
             embd_inp = inp_pfx;
             embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-            embd_inp.push_back(llama_token_middle(model));
+
+            const llama_token middle_token = llama_token_middle(model);
+            if (middle_token >= 0) {
+                embd_inp.push_back(middle_token);
+            }
+
             embd.clear();
             n_remain = params.n_predict;
             n_past = 0;
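Both infill hunks add the same guard: `llama_token_middle` can report that the model defines no fill-in-the-middle token (a negative id), and unconditionally pushing that value would feed the decoder an invalid token. Generic form of the pattern as a sketch (`push_if_present` is a made-up helper name, not part of the codebase):

```cpp
#include <vector>
#include "llama.h"

// Append an optional special token only when the vocab actually defines it.
static void push_if_present(std::vector<llama_token> & out, llama_token tok) {
    if (tok >= 0) { // absent special tokens are reported as a negative id
        out.push_back(tok);
    }
}
```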
examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift

@@ -131,23 +131,30 @@ class LlamaState: ObservableObject {
 
         messageLog += "\(text)"
 
+        Task.detached {
         while await llamaContext.n_cur < llamaContext.n_len {
             let result = await llamaContext.completion_loop()
-            messageLog += "\(result)"
+            await MainActor.run {
+                self.messageLog += "\(result)"
+            }
         }
 
         let t_end = DispatchTime.now().uptimeNanoseconds
-        let t_generation = Double(t_end - t_heat_end) / NS_PER_S
+        let t_generation = Double(t_end - t_heat_end) / self.NS_PER_S
         let tokens_per_second = Double(await llamaContext.n_len) / t_generation
 
         await llamaContext.clear()
-        messageLog += """
+
+        await MainActor.run {
+            self.messageLog += """
             \n
             Done
             Heat up took \(t_heat)s
             Generated \(tokens_per_second) t/s\n
             """
         }
+        }
+    }
 
     func bench() async {
         guard let llamaContext else {
examples/retrieval/retrieval.cpp

@@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -160,6 +161,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
examples/server/server.cpp

@@ -1595,7 +1595,7 @@ struct server_context {
             } else {
                 std::string prompt;
                 if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
-                    json_value(task.data, "prompt", std::string());
+                    prompt = json_value(task.data, "prompt", std::string());
                 }
 
                 slot = get_available_slot(prompt);
@@ -2039,7 +2039,12 @@ struct server_context {
                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
                     prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
                     prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                    prefix_tokens.push_back(llama_token_middle(model));
+
+                    const llama_token middle_token = llama_token_middle(model);
+                    if (middle_token >= 0) {
+                        prefix_tokens.push_back(middle_token);
+                    }
+
                     prompt_tokens = prefix_tokens;
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
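The one-line server fix above is a classic discarded-return-value bug: `json_value(...)` computed the prompt and threw the result away, leaving `prompt` empty. Attributes like `[[nodiscard]]` catch this class of mistake at compile time; a minimal sketch with a stand-in function (not the real `json_value` signature):

```cpp
#include <string>

// Stand-in for a getter whose result must be used.
[[nodiscard]] static std::string get_prompt_stub(const std::string & fallback) {
    return fallback.empty() ? "default prompt" : fallback;
}

int main() {
    std::string prompt;
    get_prompt_stub("");          // compilers warn: return value discarded
    prompt = get_prompt_stub(""); // the fixed pattern: result is stored
    return prompt.empty() ? 1 : 0;
}
```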
examples/sycl/win-build-sycl.bat

@@ -13,16 +13,16 @@ if %errorlevel% neq 0 goto ERROR
 
 :: for FP16
 :: faster for long-prompt inference
-:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
+:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
 
 :: for FP32
-cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
+cmake -G "Ninja" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
 if %errorlevel% neq 0 goto ERROR
 :: build example/main only
 :: make main
 
 :: build all binary
-make -j
+cmake --build . -j
 if %errorlevel% neq 0 goto ERROR
 
 cd ..
ggml-backend.c

@@ -1706,14 +1706,16 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     bool backend_ids_changed = false;
     for (int i = 0; i < sched->graph->n_nodes; i++) {
-        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i]) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
             backend_ids_changed = true;
             break;
         }
     }
     if (!backend_ids_changed) {
         for (int i = 0; i < sched->graph->n_leafs; i++) {
-            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i]) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
                 backend_ids_changed = true;
                 break;
             }
@@ -1977,6 +1979,15 @@ int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
     return sched->n_copies;
 }
 
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
ggml-backend.h

@@ -182,6 +182,9 @@ extern "C" {
     // Initialize backend buffers from a measure graph
     GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
 
+    GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
     // Get the number of splits of the last graph
     GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
     GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
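The two accessors added to `ggml-backend.h` make it possible to enumerate a scheduler's backends without touching its internals. A usage sketch (assumes an already-constructed `ggml_backend_sched_t`; `ggml_backend_name` and `ggml_backend_sched_get_buffer_size` are existing ggml APIs, the latter visible in the diff context above):

```cpp
#include <stdio.h>
#include "ggml-backend.h"

// List every backend the scheduler can dispatch to.
static void print_sched_backends(ggml_backend_sched_t sched) {
    const int n_backends = ggml_backend_sched_get_n_backends(sched);
    for (int i = 0; i < n_backends; ++i) {
        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
        printf("backend %d/%d: %s (buffer size: %zu bytes)\n",
               i + 1, n_backends, ggml_backend_name(backend),
               ggml_backend_sched_get_buffer_size(sched, backend));
    }
}
```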
ggml-cuda.cu

@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
         }
 
         const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
     }
     return row_rounding;
 }
ggml-cuda/common.cuh

@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+static int get_mmq_y_host(const int cc) {
+    return cc >= CC_VOLTA ? 128 : 64;
 }
 
 //////////////////////
ggml-cuda/mmq.cu

@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
         default:
             GGML_ASSERT(false);
@ -8,6 +8,7 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
||||||
|
#define MMQ_NWARPS 8
|
||||||
|
|
||||||
typedef void (*load_tiles_mmq_t)(
|
typedef void (*load_tiles_mmq_t)(
|
||||||
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
||||||
|
@ -15,7 +16,7 @@ typedef void (*load_tiles_mmq_t)(
|
||||||
typedef void (*vec_dot_mmq_t)(
|
typedef void (*vec_dot_mmq_t)(
|
||||||
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
||||||
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
|
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
|
||||||
|
|
||||||
struct block_q8_1_mmq {
|
struct block_q8_1_mmq {
|
||||||
half2 ds[4];
|
half2 ds[4];
|
||||||
|
@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
||||||
|
|
||||||
// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
|
// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
|
||||||
|
|
||||||
|
static constexpr __device__ int get_mmq_y_device() {
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
static constexpr __device__ int get_mmq_y_device(int mmq_x) {
|
return 128;
|
||||||
return mmq_x >= 32 ? 128 : 64;
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||||
static constexpr __device__ int get_mmq_y_device(int mmq_x) {
|
return 128;
|
||||||
return mmq_x >= 32 ? 128 : 64;
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
|
|
||||||
return 64;
|
return 64;
|
||||||
}
|
|
||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
}
|
||||||
|
|
||||||
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
|
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
|
||||||
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
|
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
|
||||||
|
@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||||
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
static __device__ __forceinline__ void mmq_write_back_dp4a(
|
||||||
|
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||||
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
if (j >= ne1) {
|
if (j > j_max) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||||
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
if (need_check && i >= ne0) {
|
if (need_check && i > i_max) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||||
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
static __device__ __forceinline__ void mmq_write_back_mma(
|
||||||
|
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
||||||
|
|
||||||
typedef mma_int_C_I16J8 mma_C;
|
typedef mma_int_C_I16J8 mma_C;
|
||||||
|
|
||||||
const int i0 = threadIdx.y*mma_C::I;
|
const int i0 = threadIdx.y*mma_C::I;
|
||||||
|
@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < mma_C::ne; ++l) {
|
for (int l = 0; l < mma_C::ne; ++l) {
|
||||||
const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
|
const int j = j0 + mma_C::get_j(l);
|
||||||
|
|
||||||
if (j >= ne1) {
|
if (j > j_max) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
|
const int i = i0 + mma_C::get_i(l);
|
||||||
|
|
||||||
if (need_check && i >= ne0) {
|
if (need_check && i > i_max) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
|
dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
static __device__ void mul_mat_q_process_tile(
|
||||||
#if defined(RDNA3) || defined(RDNA2)
|
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
|
||||||
#endif // defined(RDNA3) || defined(RDNA2)
|
const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
|
||||||
#else
|
|
||||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
|
||||||
#else
|
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
|
||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
||||||
static __global__ void mul_mat_q(
|
|
||||||
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
|
|
||||||
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
|
|
||||||
|
|
||||||
// Skip unused template specializations for faster compilation:
|
|
||||||
if (mmq_x > get_mmq_x_max_device()) {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||||
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
||||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||||
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
constexpr int mmq_y = get_mmq_y_device();
|
||||||
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
||||||
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
||||||
|
|
||||||
|
@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
|
||||||
int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
|
int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
|
||||||
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
|
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
|
||||||
|
|
||||||
const int blocks_per_row_x = ne00 / qk;
|
constexpr int blocks_per_warp = WARP_SIZE / qi;
|
||||||
const int blocks_per_warp = WARP_SIZE / qi;
|
|
||||||
|
|
||||||
const int & ne1 = ne11;
|
|
||||||
|
|
||||||
const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
|
|
||||||
|
|
||||||
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
|
||||||
|
|
||||||
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
||||||
|
|
||||||
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
const int tile_x_max_i = ne01 - it*mmq_y - 1;
|
||||||
|
const int tile_y_max_j = ne11 - jt*mmq_x - 1;
|
||||||
|
|
||||||
load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
|
const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
||||||
|
|
||||||
|
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
|
||||||
|
|
||||||
|
load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int kr = 0; kr < qr; ++kr) {
|
for (int kr = 0; kr < qr; ++kr) {
|
||||||
|
@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
write_back(sum, dst, ne0, ne1);
|
if (fixup) {
|
||||||
|
write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
|
||||||
|
} else {
|
||||||
|
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
|
||||||
|
|
||||||
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#else
|
||||||
|
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||||
|
#else
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
|
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
static __global__ void mul_mat_q(
|
||||||
|
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||||
|
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
|
||||||
|
|
||||||
|
// Skip unused template specializations for faster compilation:
|
||||||
|
if (mmq_x > get_mmq_x_max_device()) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||||
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||||
|
constexpr int mmq_y = get_mmq_y_device();
|
||||||
|
|
||||||
|
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
||||||
|
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
|
||||||
|
{
|
||||||
|
constexpr bool fixup = false;
|
||||||
|
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
||||||
|
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
|
||||||
|
blockIdx.x, blockIdx.y, 0, ne00/qk);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
|
||||||
|
|
||||||
|
const int64_t blocks_per_ne00 = ne00 / qk;
|
||||||
|
constexpr int blocks_per_warp = WARP_SIZE / qi;
|
||||||
|
|
||||||
|
const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||||
|
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||||
|
|
||||||
|
// kbc == k block continuous, current index in continuous ijk space.
|
||||||
|
int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
|
||||||
|
const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
|
||||||
|
|
||||||
|
// kb0 == k index when doing the matrix multiplication for an output tile.
|
||||||
|
int kb0_start = kbc % blocks_per_ne00;
|
||||||
|
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
||||||
|
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
||||||
|
const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
|
||||||
|
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
|
||||||
|
|
||||||
|
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||||
|
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
||||||
|
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
|
||||||
|
it, jt, kb0_start, kb0_stop);
|
||||||
|
|
||||||
|
kbc += blocks_per_ne00;
|
||||||
|
kbc -= kbc % blocks_per_ne00;
|
||||||
|
|
||||||
|
kb0_start = 0;
|
||||||
|
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kbc >= kbc_stop) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int jt = kbc / (blocks_per_ne00*nty);
|
||||||
|
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
|
||||||
|
|
||||||
|
constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
|
||||||
|
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
||||||
|
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
|
||||||
|
it, jt, kb0_start, kb0_stop);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
|
static __global__ void mul_mat_q_stream_k_fixup(
|
||||||
|
float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
|
||||||
|
|
||||||
|
constexpr int mmq_y = get_mmq_y_device();
|
||||||
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||||
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||||
|
constexpr int blocks_per_warp = WARP_SIZE / qi;
|
||||||
|
const int64_t blocks_per_ne00 = ne00 / qk;
|
||||||
|
|
||||||
|
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
||||||
|
|
||||||
|
const int ntx = (ne11 + mmq_x - 1) / mmq_x;
|
||||||
|
const int nty = (ne01 + mmq_y - 1) / mmq_y;
|
||||||
|
|
||||||
|
bool any_fixup = false;
|
||||||
|
|
||||||
|
const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
|
||||||
|
const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
|
||||||
|
|
||||||
|
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
|
||||||
|
const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
|
||||||
|
const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
|
||||||
|
|
||||||
|
// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
|
||||||
|
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int jt = kbc_stop / (blocks_per_ne00*nty);
|
||||||
|
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
|
||||||
|
|
||||||
|
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
|
||||||
|
if (it != blockIdx.x || jt != blockIdx.y) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
any_fixup = true;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!any_fixup) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
|
||||||
|
|
||||||
|
const int i_max = ne01 - blockIdx.x*mmq_y - 1;
|
||||||
|
const int j_max = ne11 - blockIdx.y*mmq_x - 1;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (j > j_max) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
if (need_check && i > i_max) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct mmq_args {
|
struct mmq_args {
|
||||||
|
@ -1987,124 +2139,151 @@ struct mmq_args {
|
||||||
int64_t ne0;
|
int64_t ne0;
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr int mmq_get_nwarps(int mmq_x) {
|
|
||||||
return mmq_x >= 32 ? 8 : 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
|
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
|
||||||
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
|
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
|
||||||
const int nwarps = mmq_get_nwarps(mmq_x);
|
|
||||||
|
|
||||||
const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||||
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
||||||
return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
|
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ggml_type type, int mmq_x, int nwarps>
|
template <ggml_type type, int mmq_x>
|
||||||
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||||
const int id = ggml_cuda_get_device();
|
const int id = ggml_cuda_get_device();
|
||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const int mmq_y = get_mmq_y_host(cc, mmq_x);
|
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||||
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
|
|
||||||
const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
|
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
|
||||||
const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
|
|
||||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
|
||||||
|
|
||||||
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
|
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||||
if (!shmem_limit_raised[id]) {
|
if (!shmem_limit_raised[id]) {
|
||||||
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
||||||
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
||||||
shmem_limit_raised[id] = true;
|
shmem_limit_raised[id] = true;
|
||||||
}
|
}
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
|
||||||
|
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||||
|
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
|
||||||
|
const dim3 block_nums_xy_tiling(nty, ntx, 1);
|
||||||
|
|
||||||
|
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
|
||||||
|
if (!use_stream_k) {
|
||||||
if (args.ne01 % mmq_y == 0) {
|
if (args.ne01 % mmq_y == 0) {
|
||||||
const bool need_check = false;
|
constexpr bool need_check = false;
|
||||||
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
|
||||||
(args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
||||||
} else {
|
} else {
|
||||||
const bool need_check = true;
|
constexpr bool need_check = true;
|
||||||
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
|
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
|
||||||
(args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const dim3 block_nums_mmq(nsm, 1, 1);
|
||||||
|
|
||||||
|
ggml_cuda_pool & pool = ctx.pool();
|
||||||
|
ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
|
||||||
|
|
||||||
|
if (args.ne01 % mmq_y == 0) {
|
||||||
|
constexpr bool need_check = false;
|
||||||
|
|
||||||
|
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
|
||||||
|
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
||||||
|
|
||||||
|
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
|
||||||
|
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
|
||||||
|
} else {
|
||||||
|
constexpr bool need_check = true;
|
||||||
|
|
||||||
|
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
|
||||||
|
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
|
||||||
|
|
||||||
|
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
|
||||||
|
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ggml_type type>
|
template <ggml_type type>
|
||||||
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||||
const int id = ggml_cuda_get_device();
|
const int id = ggml_cuda_get_device();
|
||||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const int smpbo = ggml_cuda_info().devices[id].smpbo;
|
const int smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||||
|
|
||||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||||
const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||||
|
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
|
||||||
|
|
||||||
int mmq_x_best = 0;
|
int mmq_x_best = 0;
|
||||||
int nwaves_best = INT_MAX;
|
int nparts_best = INT_MAX;
|
||||||
|
|
||||||
for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
|
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
|
||||||
const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
|
const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
|
||||||
const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
|
const int nwaves_xy_tiling = ntiles_x*block_num_y;
|
||||||
|
|
||||||
if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
|
const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
|
||||||
|
|
||||||
|
if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
|
||||||
mmq_x_best = mmq_x;
|
mmq_x_best = mmq_x;
|
||||||
nwaves_best = nwaves;
|
nparts_best = nparts;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
    switch (mmq_x_best) {
        case   8:
            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
            launch_mul_mat_q<type,   8>(ctx, args, stream);
            break;
        case  16:
            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
            launch_mul_mat_q<type,  16>(ctx, args, stream);
            break;
        case  24:
            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
            launch_mul_mat_q<type,  24>(ctx, args, stream);
            break;
        case  32:
            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
            launch_mul_mat_q<type,  32>(ctx, args, stream);
            break;
        case  40:
            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
            launch_mul_mat_q<type,  40>(ctx, args, stream);
            break;
        case  48:
            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
            launch_mul_mat_q<type,  48>(ctx, args, stream);
            break;
        case  56:
            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
            launch_mul_mat_q<type,  56>(ctx, args, stream);
            break;
        case  64:
            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
            launch_mul_mat_q<type,  64>(ctx, args, stream);
            break;
        case  72:
            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
            launch_mul_mat_q<type,  72>(ctx, args, stream);
            break;
        case  80:
            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
            launch_mul_mat_q<type,  80>(ctx, args, stream);
            break;
        case  88:
            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
            launch_mul_mat_q<type,  88>(ctx, args, stream);
            break;
        case  96:
            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
            launch_mul_mat_q<type,  96>(ctx, args, stream);
            break;
        case 104:
            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
            launch_mul_mat_q<type, 104>(ctx, args, stream);
            break;
        case 112:
            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
            launch_mul_mat_q<type, 112>(ctx, args, stream);
            break;
        case 120:
            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
            launch_mul_mat_q<type, 120>(ctx, args, stream);
            break;
        case 128:
            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
            launch_mul_mat_q<type, 128>(ctx, args, stream);
            break;
        default:
            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);

@@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
    }
}

#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);

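Aside on `DECL_MMQ_CASE`: it expands to an explicit template instantiation of `mul_mat_q_case` for one quantization type, so each type can be compiled in its own translation unit while the `extern` declarations here merely reference them. A tiny self-contained illustration of that extern-instantiation pattern (hypothetical names, not the project's macro):

#include <cstdio>

template <int N>
void report(const char * tag) {
    std::printf("%s: %d\n", tag, N);
}

// The macro expands to an explicit-instantiation declaration; prefixing it
// with `extern` tells other translation units the instantiation exists
// elsewhere.
#define DECL_REPORT_CASE(n) template void report<n>(const char * tag)

DECL_REPORT_CASE(8);        // instantiate report<8> in this TU
extern DECL_REPORT_CASE(4); // declared here, defined in some other TU

int main() {
    report<8>("tile");
}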
@@ -735,6 +735,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
}

static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
    for (size_t i = 0, n = 3; i < n; ++i) {
        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
            return false;
        }
    }

    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
699	ggml-quants.c

@@ -8815,7 +8815,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
#endif
}

#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
static const int8_t keven_signs_q2xs[1024] = {
     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
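For reference, `keven_signs_q2xs` packs, for each 7-bit sign pattern, eight bytes of plus or minus one whose count of minus ones is even: the eighth sign is the parity of the first seven, which is why 7 stored bits suffice to address 128 blocks of 8 signs (1024 bytes). A sketch of how such a table can be generated, checked against the first rows above (this is my reading of the layout, not code from the patch):

#include <cstdio>

// Entry k (0..127) holds 8 signs; bits 0..6 of k give the first seven and
// the eighth restores even parity of the -1 count.
int main() {
    signed char table[1024];
    for (int k = 0; k < 128; ++k) {
        int parity = 0;
        for (int j = 0; j < 7; ++j) {
            const int bit = (k >> j) & 1;
            table[8*k + j] = bit ? -1 : 1;
            parity ^= bit;
        }
        table[8*k + 7] = parity ? -1 : 1; // implicit eighth sign
    }
    // Spot-check against the first row of keven_signs_q2xs:
    for (int j = 0; j < 16; ++j) {
        std::printf("%d, ", table[j]);
    }
    std::printf("\n"); // expected: 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1,
}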
@@ -8948,6 +8948,61 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__AVX__)
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    uint32_t aux32[4];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * restrict q2 = x[i].qs;
        const int8_t   * restrict q8 = y[i].qs;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
            const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
            const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
            const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
            const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
            const uint16_t ls1 = aux32[1] >> 28;
            const uint16_t ls2 = aux32[3] >> 28;
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
        }

        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
    }

    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__POWER9_VECTOR__)
    const vector int v0 = vec_splats((int32_t)0);
    vector float vsumf0 = vec_splats(0.0f);
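The new `__AVX__` paths in this file mirror the existing AVX2 code but keep all integer work in paired 128-bit halves, since pre-AVX2 hardware lacks 256-bit integer instructions; only the final float accumulation is widened via `MM256_SET_M128I`. A self-contained sketch of that widen-and-accumulate step (the macro here is my own stand-in with the same shape):

#include <immintrin.h>
#include <cstdio>

// Combine two 128-bit integer partial sums into one 256-bit float
// accumulator, matching the accumf update used in the AVX paths above.
#define MY_SET_M128I(hi, lo) _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)

int main() {
    __m128i lo = _mm_set1_epi32(3);
    __m128i hi = _mm_set1_epi32(5);
    __m256  acc = _mm256_setzero_ps();
    const float d = 0.5f;

    acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d),
                                      _mm256_cvtepi32_ps(MY_SET_M128I(hi, lo))), acc);

    float out[8];
    _mm256_storeu_ps(out, acc);
    std::printf("%g %g\n", out[0], out[7]); // 1.5 2.5
}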
@@ -9291,6 +9346,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
    }

    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__AVX__)
    const __m128i mone = _mm_set1_epi8(1);
    static const char block_sign_shuffle_mask_1[32] = {
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
    };
    static const char block_sign_shuffle_mask_2[32] = {
        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
    };
    static const uint8_t bit_selector_mask_bytes[32] = {
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    };

    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);

    static const uint8_t k_bit_helper[32] = {
        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
    };
    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
    const __m128i m511 = _mm_set1_epi16(511);
    const __m128i m4 = _mm_set1_epi8(0xf);
    const __m128i m1 = _mm_set1_epi8(1);

    uint64_t aux64;

    // somewhat hacky, but gives a significant boost in performance
    __m256i aux_gindex;
    const uint16_t * gindex = (const uint16_t *)&aux_gindex;

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * restrict q2 = x[i].qs;
        const int8_t   * restrict q8 = y[i].qs;

        memcpy(&aux64, x[i].scales, 8);
        __m128i stmp = _mm_set1_epi64x(aux64);
        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {

            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));

            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);

            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);

            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;

            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);

            // AVX2 full_signs_1 is full_sign_bits_0 here
            // AVX2 full_signs_2 is full_sign_bits_1 here
            __m128i signs_0, signs_1;
            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));

            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));

            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));

            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));

            const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
            const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
            const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
            const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
            const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
            const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
            const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
            const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);

            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));

            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
        }

        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
    }

    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__loongarch_asx)

    const __m256i mone = __lasx_xvreplgr2vr_b(1);
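About the "somewhat hacky" sign recovery above: iq2_xs stores only 7 of each group's 8 sign bits, and as I read it `k_bit_helper` is a 16-entry parity lookup (0x80 where the 4-bit index has odd popcount), so `_mm_shuffle_epi8` recovers the implicit eighth bit from the stored ones. A small verification of that reading:

#include <cstdio>

// Rebuild the first half of k_bit_helper from the parity rule and compare
// by eye with the table in the patch.
int main() {
    unsigned char helper[16];
    for (int i = 0; i < 16; ++i) {
        helper[i] = (__builtin_popcount(i) & 1) ? 0x80 : 0x00; // odd parity -> 0x80
    }
    for (int i = 0; i < 16; ++i) {
        std::printf("0x%02x, ", helper[i]);
    }
    std::printf("\n"); // expected: 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, ...
}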
@@ -9694,6 +9908,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *

    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__AVX__)
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
    };

    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    };

    const __m128i m4 = _mm_set1_epi8(0xf);
    const __m128i m1 = _mm_set1_epi8(1);

    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);

    uint64_t aux64;

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t  * restrict qs = x[i].qs;
        const uint8_t  * restrict qh = x[i].qh;
        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
        const int8_t   * restrict q8 = y[i].qs;

        memcpy(&aux64, x[i].scales, 8);
        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
                                                  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
            qs += 8;

            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
            __m128i aux128_1 = aux128_0;
            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);

            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
            aux128_1 = aux128_0;
            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);

            signs += 4;

            const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
            const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
            const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
            const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);

            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
        }

        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
    }

    *s = 0.125f * hsum_float_8(accumf);

#elif defined(__POWER9_VECTOR__)
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
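The sign application here avoids `_mm_sign_epi8` and instead uses `_mm_sub_epi8(_mm_xor_si128(s, q8), s)`, which conditionally negates each byte when `s` is an all-ones mask (from `_mm_cmpeq_epi8`). The scalar identity behind it, shown standalone:

#include <cstdio>
#include <cstdint>

// (x ^ s) - s negates x when s is 0xff (two's complement: ~x + 1 == -x)
// and leaves it unchanged when s is 0x00.
int main() {
    const int8_t q8 = 37;
    for (int neg = 0; neg <= 1; ++neg) {
        const int8_t s = neg ? (int8_t)0xff : (int8_t)0x00;
        const int8_t r = (int8_t)((q8 ^ s) - s);
        std::printf("mask=0x%02x -> %d\n", (uint8_t)s, r); // 37, then -37
    }
}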
@@ -10020,6 +10326,63 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void

    *s = 0.25f * hsum_float_8(accumf);

#elif defined(__AVX__)
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    uint32_t aux32[2];

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * restrict q3 = x[i].qs;
        const uint8_t * restrict gas = x[i].qs + QK_K/4;
        const int8_t  * restrict q8 = y[i].qs;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
            q3 += 8;
            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
            q3 += 8;
            memcpy(aux32, gas, 8); gas += 8;
            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
            const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
            const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
            const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
            const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
            const uint16_t ls1 = aux32[0] >> 28;
            const uint16_t ls2 = aux32[1] >> 28;
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
        }

        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
    }

    *s = 0.25f * hsum_float_8(accumf);

#elif defined(__POWER9_VECTOR__)
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

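A recurring pattern in these kernels: each 32-value sub-block contributes an integer dot product weighted by the odd multiplier `2*ls+1`, with `ls` a 4-bit scale pulled from the packed metadata, and the trailing constant (0.25f here, 0.125f in the iq2 paths) compensates, as I read it, for the fixed-point factors baked into the grids and scales. The scalar shape of that accumulation:

#include <cstdio>

// Per-block accumulation: integer sub-block dot products weighted by the
// odd scale (2*ls + 1), then scaled by d and a fixed-point constant.
float block_sum(const int *dots, const int *ls, int nsub, float d, float fixpoint) {
    int sumi = 0;
    for (int j = 0; j < nsub; ++j) {
        sumi += dots[j] * (2*ls[j] + 1);
    }
    return fixpoint * d * (float) sumi;
}

int main() {
    const int dots[2] = {100, -40};
    const int ls[2]   = {3, 1};
    std::printf("%g\n", block_sum(dots, ls, 2, 0.5f, 0.25f)); // 0.25*0.5*(700-120) = 72.5
}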
@@ -10371,6 +10734,112 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *

    *s = hsum_float_8(accumf);

#elif defined(__AVX__)
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
    };

    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    };

    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);

    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
    const __m128i idx_mask = _mm_set1_epi32(256);

    typedef union {
        __m128i  vec[4];
        uint32_t index[16];
    } index_t;

    index_t idx;

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * restrict qs = x[i].qs;
        const uint8_t * restrict qh = x[i].qh;
        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
        const int8_t  * restrict q8 = y[i].qs;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
            idx.vec[1] = idx.vec[0];
            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
            idx.vec[3] = idx.vec[2];

            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);

            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
            idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));

            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);

            __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
            __m128i aux128_1 = aux128_0;
            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);

            aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
            aux128_1 = aux128_0;
            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);

            signs += 4;

            const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
            const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
            const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
            const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
            const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
        }

        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
    }

    *s = hsum_float_8(accumf);

#elif defined(__POWER9_VECTOR__)
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
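The `index_t` union in this path computes grid indices with SSE arithmetic and then reads them back lane by lane for scalar table lookups (type punning through a union is well defined in C). A minimal standalone illustration of the same trick:

#include <immintrin.h>
#include <cstdio>
#include <cstdint>

// Build packed indices in a vector register, then read them through a union
// for scalar table lookups, as index_t does above.
typedef union {
    __m128i  vec;
    uint32_t index[4];
} idx_t;

int main() {
    idx_t idx;
    idx.vec = _mm_set_epi32(7, 5, 3, 1);   // lanes 3..0
    static const uint32_t table[8] = {0, 10, 20, 30, 40, 50, 60, 70};
    for (int j = 0; j < 4; ++j) {
        std::printf("%u ", table[idx.index[j]]); // 10 30 50 70
    }
    std::printf("\n");
}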
@@ -10608,6 +11077,14 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
}

#if defined(__AVX__)
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x);
    const __m128i sy = _mm_sign_epi8(y, x);
    return _mm_maddubs_epi16(ax, sy);
}
#endif

#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x);
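Why `mul_add_epi8_sse` is written this way: `_mm_maddubs_epi16` treats its first operand as unsigned, so the helper first takes `|x|` and moves the sign of `x` onto `y`, using the identity `|x| * (sign(x)*y) == x*y`. A scalar model of one output lane:

#include <cstdio>

// Per byte pair: |x| * (sign(x)*y) == x*y, with adjacent products summed
// into a 16-bit lane, as maddubs does.
int main() {
    const int x[2] = {-3, 7}, y[2] = {5, -2};
    int lane = 0;
    for (int j = 0; j < 2; ++j) {
        const int ax = x[j] < 0 ? -x[j] : x[j];               // _mm_sign_epi8(x, x)
        const int sy = x[j] < 0 ? -y[j] : (x[j] ? y[j] : 0);  // _mm_sign_epi8(y, x)
        lane += ax * sy;                                      // pairwise sum
    }
    std::printf("%d\n", lane); // (-3*5) + (7*-2) = -29
}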
@@ -10725,6 +11202,54 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void

    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;

#elif defined __AVX__
    __m256 accum = _mm256_setzero_ps();
    float accum1 = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint16_t * qh = x[i].qh;

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        int sumi1 = 0;
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
            qs += 8;
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;

            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));

            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
        }

        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
        accum1 += d * sumi1;

    }

    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;

#elif defined(__POWER9_VECTOR__)
    const vector unsigned char v0 = vec_splats((unsigned char)0x0);
    const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
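On the `sumi1` term above: as I read the iq1_s path, each sub-block also carries a constant shift of plus or minus IQ1S_DELTA, and since a constant times a dot product with q8 reduces to that constant times the sum of the q8 values, the precomputed `bsums` supply it without another dot product. A one-sub-block numeric sketch of that contribution:

#include <cstdio>

// Delta correction for one sub-block: sign and scale from qh, q8 sum from
// the precomputed bsums, all folded under d * IQ1S_DELTA at the end.
int main() {
    const int bsum = 120;    // sum of the sub-block's q8 values (precomputed)
    const int ls = 5;        // sub-block scale
    const int sign = -1;     // from the high bit of qh
    const float d = 0.25f, IQ1S_DELTA = 0.125f;
    std::printf("%g\n", d * IQ1S_DELTA * sign * ls * bsum); // -18.75
}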
@@ -11063,6 +11588,92 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void

    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);

#elif defined __AVX__
    const __m128i mask = _mm_set1_epi16(0x7);
    const __m128i mone = _mm_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q1b_1_0 = _mm_set_epi64x(
                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
            const __m128i q1b_1_1 = _mm_set_epi64x(
                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
            const __m128i q1b_2_0 = _mm_set_epi64x(
                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
            const __m128i q1b_2_1 = _mm_set_epi64x(
                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;

            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);

            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                    qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);

            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);

            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);

            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
            const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);

            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));

            qs += 8; qh += 4;
        }

        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));

        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
    }

    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);

#else

    int sum1[2], sum2[2], delta[4];
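The `scale.u16` line above (using a u16/f16 union declared earlier in the function, outside this hunk) gathers the shared fp16 block scale from the top nibble of four 16-bit scale words. The bit arithmetic, isolated:

#include <cstdio>
#include <cstdint>

// Gather the top nibble of each of four 16-bit words into one 16-bit value:
// bits 0-3 from sc[0], 4-7 from sc[1], 8-11 from sc[2], 12-15 from sc[3].
int main() {
    const uint16_t sc[4] = {0x1abc, 0x2def, 0x3123, 0x4456};
    const uint16_t u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0)
                       | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    std::printf("0x%04x\n", u16); // 0x4321
}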
@@ -11193,6 +11804,44 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *

    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined __AVX__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b = _mm_set1_epi8(0x0f);
    const __m128i mone = _mm_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (int ib = 0; ib < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[1].qs + 1);

        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
        const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
        const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
        const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
        const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
                _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
                _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);

        y += 2;
        x += 2;
    }

    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined(__POWER9_VECTOR__)
    const vector signed char lowMask = vec_splats((signed char)0xF);
    const vector signed int v0 = vec_splats((int32_t)0);
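`kvalues_iq4nl` is a 16-entry table of dequantized values, and `_mm_shuffle_epi8` with the 4-bit codes as indices performs sixteen parallel table lookups. A scalar model of the decode; the table contents below are what I believe the real kvalues_iq4nl holds, but treat them as illustrative:

#include <cstdio>
#include <cstdint>

// Each byte of qs packs two 4-bit codes; a 16-entry signed table maps each
// code to its dequantized value.
static const int8_t kvalues[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

int main() {
    const uint8_t qs = 0x9f;   // low nibble 0xf, high nibble 0x9
    std::printf("%d %d\n", kvalues[qs & 0xf], kvalues[qs >> 4]); // 113 13
}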
@@ -11383,6 +12032,54 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *

    *s = hsum_float_8(accum);

#elif defined __AVX__
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        uint16_t sh = x[ibl].scales_h;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
            sh >>= 4;
            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
        }
        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
    }

    *s = hsum_float_8(accum);

#elif defined(__POWER9_VECTOR__)
    const vector signed char lowMask = vec_splats((signed char)0xF);
    const vector int v0 = vec_splats((int32_t)0);
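The IQ4_XS scale decode above reassembles 6-bit sub-block scales from a low nibble in `scales_l` (two per byte) plus two high bits streamed out of `scales_h`, then recenters by subtracting 32. Isolated, with illustrative inputs:

#include <cstdio>
#include <cstdint>

// 6-bit scale = 4 low bits (scales_l nibble) | 2 high bits (from sh), - 32.
int main() {
    const uint8_t scales_l = 0xa7;  // sub-block 0 nibble = 0x7, sub-block 1 = 0xa
    const uint16_t sh = 0x6;        // high bit pairs: 0b10 (sub 0), 0b01 (sub 1)
    const int ls1 = ((scales_l & 0xf) | ((sh << 4) & 0x30)) - 32;
    const int ls2 = ((scales_l >> 4) | ((sh << 2) & 0x30)) - 32;
    std::printf("%d %d\n", ls1, ls2); // 7 -6
}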
7006	ggml-sycl.cpp
File diff suppressed because it is too large
@@ -14,5 +14,10 @@
#define GGML_SYCL_BACKEND_HPP

#include "common.hpp"
#include "convert.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
#include "mmq.hpp"
#include "mmvq.hpp"

#endif // GGML_SYCL_BACKEND_HPP
544	ggml-sycl/convert.cpp	Normal file
@@ -0,0 +1,544 @@
#include "convert.hpp"
#include "dequantize.hpp"
#include "presets.hpp"

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
                             const sycl::nd_item<3> &item_ct1) {
    const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                       item_ct1.get_local_id(2));

    if (i >= k) {
        return;
    }

    const int ib = i/qk; // block index
    const int iqs = (i%qk)/qr; // quant index
    const int iybs = i - i%qk; // y block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[iybs + iqs + 0] = v.x();
    y[iybs + iqs + y_offset] = v.y();
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_sycl(const void *__restrict__ vx,
                                  dst_t *__restrict__ y, const int k,
                                  dpct::queue_ptr stream) {
    const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});
        stream->parallel_for(
            sycl::nd_range<3>(
                sycl::range<3>(1, 1, num_blocks) *
                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
            [=](sycl::nd_item<3> item_ct1) {
                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
            });
    }
}
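Each work-item in `dequantize_block` expands two quantized values, which is why the index and the `num_blocks` ceil-division both carry a factor of 2, and writes them `y_offset` apart to match the interleaved layout produced by `dequantize_kernel`. A scalar sketch of the index math for, say, qk = 32 and qr = 2:

#include <cstdio>

// Index math of dequantize_block for one work-item, scalar version.
// qk = values per quant block, qr = quantization ratio.
void expand_one(int tid, int qk, int qr) {
    const int i = 2*tid;            // each work-item covers two values
    const int ib   = i/qk;          // which quant block
    const int iqs  = (i%qk)/qr;     // index inside the block
    const int iybs = i - i%qk;      // start of this block in the output
    const int y_offset = qr == 1 ? 1 : qk/2;
    std::printf("tid=%d writes y[%d] and y[%d] from block %d\n",
                tid, iybs + iqs, iybs + iqs + y_offset, ib);
}

int main() {
    for (int tid = 0; tid < 4; ++tid) {
        expand_one(tid, /*qk=*/32, /*qr=*/2); // y[0]/y[16], y[1]/y[17], ...
    }
}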

template <typename dst_t>
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q2_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q2_K(vx, y, item_ct1);
                             });
    }

#endif
}

template <typename dst_t>
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q3_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q3_K(vx, y, item_ct1);
                             });
    }
#endif
}

template <typename dst_t>
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);
                             });
    }
}

template <typename dst_t>
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);
                             });
    }
}


template <typename dst_t>
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb = k / QK_K;
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q4_K(vx, y, item_ct1);
                             });
    }
}

template <typename dst_t>
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
                                     dpct::queue_ptr stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 64),
|
||||||
|
sycl::range<3>(1, 1, 64)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_q5_K(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_q5_K(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
#if QK_K == 256
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 64),
|
||||||
|
sycl::range<3>(1, 1, 64)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_q6_K(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_q6_K(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq1_s(
|
||||||
|
vx, y, item_ct1, iq1s_grid_gpu
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq1_m(
|
||||||
|
vx, y, item_ct1, iq1s_grid_gpu
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq2_xxs(
|
||||||
|
vx, y, item_ct1, iq2xxs_grid,
|
||||||
|
ksigns_iq2xs, kmask_iq2xs);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq2_xs(
|
||||||
|
vx, y, item_ct1, iq2xs_grid,
|
||||||
|
ksigns_iq2xs, kmask_iq2xs);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq2_s(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq3_xxs(
|
||||||
|
vx, y, item_ct1, iq3xxs_grid,
|
||||||
|
ksigns_iq2xs, kmask_iq2xs);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq3_s(
|
||||||
|
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = (k + QK_K - 1) / QK_K;
|
||||||
|
#if QK_K == 64
|
||||||
|
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq4_xs(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int nb = (k + QK_K - 1) / QK_K;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
|
sycl::range<3>(1, 1, 32),
|
||||||
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq4_nl(vx, y, item_ct1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename src_t, typename dst_t>
|
||||||
|
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
|
||||||
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||||
|
item_ct1.get_local_id(2);
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const src_t * x = (src_t *) vx;
|
||||||
|
|
||||||
|
y[i] = x[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename src_t, typename dst_t>
|
||||||
|
static void convert_unary_sycl(const void *__restrict__ vx,
|
||||||
|
dst_t *__restrict__ y, const int k,
|
||||||
|
dpct::queue_ptr stream) {
|
||||||
|
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
||||||
|
{
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(
|
||||||
|
sycl::range<3>(1, 1, num_blocks) *
|
||||||
|
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
convert_unary<src_t>(vx, y, k, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
case GGML_TYPE_Q2_K:
|
||||||
|
return dequantize_row_q2_K_sycl;
|
||||||
|
case GGML_TYPE_Q3_K:
|
||||||
|
return dequantize_row_q3_K_sycl;
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
|
return dequantize_row_q4_K_sycl;
|
||||||
|
case GGML_TYPE_Q5_K:
|
||||||
|
return dequantize_row_q5_K_sycl;
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
return dequantize_row_q6_K_sycl;
|
||||||
|
case GGML_TYPE_IQ1_S:
|
||||||
|
return dequantize_row_iq1_s_sycl;
|
||||||
|
case GGML_TYPE_IQ1_M:
|
||||||
|
return dequantize_row_iq1_m_sycl;
|
||||||
|
case GGML_TYPE_IQ2_XXS:
|
||||||
|
return dequantize_row_iq2_xxs_sycl;
|
||||||
|
case GGML_TYPE_IQ2_XS:
|
||||||
|
return dequantize_row_iq2_xs_sycl;
|
||||||
|
case GGML_TYPE_IQ2_S:
|
||||||
|
return dequantize_row_iq2_s_sycl;
|
||||||
|
case GGML_TYPE_IQ3_XXS:
|
||||||
|
return dequantize_row_iq3_xxs_sycl;
|
||||||
|
case GGML_TYPE_IQ3_S:
|
||||||
|
return dequantize_row_iq3_s_sycl;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
return dequantize_row_iq4_xs_sycl;
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
return dequantize_row_iq4_nl_sycl;
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
return convert_unary_sycl<float>;
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
return dequantize_row_q4_0_sycl;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
return dequantize_row_q4_1_sycl;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
case GGML_TYPE_Q2_K:
|
||||||
|
return dequantize_row_q2_K_sycl;
|
||||||
|
case GGML_TYPE_Q3_K:
|
||||||
|
return dequantize_row_q3_K_sycl;
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
|
return dequantize_row_q4_K_sycl;
|
||||||
|
case GGML_TYPE_Q5_K:
|
||||||
|
return dequantize_row_q5_K_sycl;
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
return dequantize_row_q6_K_sycl;
|
||||||
|
case GGML_TYPE_IQ1_S:
|
||||||
|
return dequantize_row_iq1_s_sycl;
|
||||||
|
case GGML_TYPE_IQ1_M:
|
||||||
|
return dequantize_row_iq1_m_sycl;
|
||||||
|
case GGML_TYPE_IQ2_XXS:
|
||||||
|
return dequantize_row_iq2_xxs_sycl;
|
||||||
|
case GGML_TYPE_IQ2_XS:
|
||||||
|
return dequantize_row_iq2_xs_sycl;
|
||||||
|
case GGML_TYPE_IQ2_S:
|
||||||
|
return dequantize_row_iq2_s_sycl;
|
||||||
|
case GGML_TYPE_IQ3_XXS:
|
||||||
|
return dequantize_row_iq3_xxs_sycl;
|
||||||
|
case GGML_TYPE_IQ3_S:
|
||||||
|
return dequantize_row_iq3_s_sycl;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
return dequantize_row_iq4_xs_sycl;
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
return dequantize_row_iq4_nl_sycl;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
return convert_unary_sycl<sycl::half>;
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
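Note on the two lookup functions above: both return nullptr for quant types they cannot handle, so callers are expected to check the pointer before dispatching. Below is a minimal, self-contained C++ sketch of the same table-dispatch pattern, not part of the patch; demo_type, demo_convert_t, and the two converters are hypothetical stand-ins for ggml_type and the SYCL converters.

#include <cstdio>

enum demo_type { DEMO_Q4_0, DEMO_F32, DEMO_UNSUPPORTED };
typedef void (*demo_convert_t)(const void *src, float *dst, int k);

// trivial stand-ins for the real dequantizers
static void convert_q4_0(const void *, float *dst, int k) {
    for (int i = 0; i < k; ++i) dst[i] = 0.0f;    // placeholder output
}
static void convert_f32(const void *src, float *dst, int k) {
    const float *x = (const float *) src;
    for (int i = 0; i < k; ++i) dst[i] = x[i];    // identity copy, like convert_unary
}

static demo_convert_t get_converter(demo_type t) {
    switch (t) {
        case DEMO_Q4_0: return convert_q4_0;
        case DEMO_F32:  return convert_f32;
        default:        return nullptr;           // unsupported: caller must check
    }
}

int main() {
    float src[2] = {1.0f, 2.0f}, dst[2];
    demo_convert_t fn = get_converter(DEMO_F32);
    if (fn != nullptr) fn(src, dst, 2);
    printf("%f %f\n", dst[0], dst[1]);            // prints 1.000000 2.000000
    return 0;
}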
27 ggml-sycl/convert.hpp Normal file
@@ -0,0 +1,27 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_CONVERT_HPP
#define GGML_SYCL_CONVERT_HPP

#include "common.hpp"

template <typename T>
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
                             int k, dpct::queue_ptr stream);
typedef to_t_sycl_t<float> to_fp32_sycl_t;
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);

#endif // GGML_SYCL_CONVERT_HPP
690 ggml-sycl/dequantize.hpp Normal file
@@ -0,0 +1,690 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_DEQUANTIZE_HPP
#define GGML_SYCL_DEQUANTIZE_HPP

#include "common.hpp"

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);

static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
                                            const int iqs, dfloat2 &v) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const dfloat d = x[ib].d;

    const int vui = x[ib].qs[iqs];

    v.x() = vui & 0xF;
    v.y() = vui >> 4;

#ifdef GGML_SYCL_F16
    // v = v - {8.0f, 8.0f};
    // v = v * {d, d};
    v.s0() = (v.s0() - 8.0f) * d;
    v.s1() = (v.s1() - 8.0f) * d;

#else
    v.x() = (v.x() - 8.0f) * d;
    v.y() = (v.y() - 8.0f) * d;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
                                            const int iqs, dfloat2 &v) {
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const dfloat d = x[ib].dm[0];
    const dfloat m = x[ib].dm[1];

    const int vui = x[ib].qs[iqs];

    v.x() = vui & 0xF;
    v.y() = vui >> 4;

#ifdef GGML_SYCL_F16
    // v = v * {d, d};
    // v = v + {m, m};
    v.s0() = (v.s0() * d) + m;
    v.s1() = (v.s1() * d) + m;

#else
    v.x() = (v.x() * d) + m;
    v.y() = (v.y() * d) + m;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
                                            const int iqs, dfloat2 &v) {
    const block_q5_0 * x = (const block_q5_0 *) vx;

    const dfloat d = x[ib].d;

    uint32_t qh;
    memcpy(&qh, x[ib].qh, sizeof(qh));

    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;

    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
    v.y() = ((x[ib].qs[iqs] >>  4) | xh_1);

#ifdef GGML_SYCL_F16
    // v = v - {16.0f, 16.0f};
    // v = v * {d, d};
    v.s0() = (v.s0() - 16.0f) * d;
    v.s1() = (v.s1() - 16.0f) * d;

#else
    v.x() = (v.x() - 16.0f) * d;
    v.y() = (v.y() - 16.0f) * d;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
                                            const int iqs, dfloat2 &v) {
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const dfloat d = x[ib].dm[0];
    const dfloat m = x[ib].dm[1];

    uint32_t qh;
    memcpy(&qh, x[ib].qh, sizeof(qh));

    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;

    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
    v.y() = ((x[ib].qs[iqs] >>  4) | xh_1);

#ifdef GGML_SYCL_F16
    // v = v * {d, d};
    // v = v + {m, m};
    v.s0() = (v.s0() * d) + m;
    v.s1() = (v.s1() * d) + m;
#else
    v.x() = (v.x() * d) + m;
    v.y() = (v.y() * d) + m;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
                                            const int iqs, dfloat2 &v) {
    const block_q8_0 * x = (const block_q8_0 *) vx;

    const dfloat d = x[ib].d;

    v.x() = x[ib].qs[iqs + 0];
    v.y() = x[ib].qs[iqs + 1];

#ifdef GGML_SYCL_F16
    // v = v * {d, d};
    v.s0() *= d;
    v.s1() *= d;
#else
    v.x() *= d;
    v.y() *= d;
#endif // GGML_SYCL_F16
}

template<typename dst_t>
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
                                  const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);

    // assume 32 threads
    const int tid = item_ct1.get_local_id(2);
    const int il  = tid/8;
    const int ir  = tid%8;
    const int ib = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
    const float d = sycl::vec<sycl::half, 1>(x->d)
                        .convert<float, sycl::rounding_mode::automatic>()[0];
    const float dm = -8*d;

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = d * (q[l] & 0xF) + dm;
        y[l+16] = d * (q[l] >>  4) + dm;
    }
}

template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
                                  const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);

    // assume 32 threads
    const int tid = item_ct1.get_local_id(2);
    const int il  = tid/8;
    const int ir  = tid%8;
    const int ib = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
    const sycl::float2 d =
        x->dm.convert<float, sycl::rounding_mode::automatic>();

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l + 0] = d.x() * (q[l] & 0xF) + d.y();
        y[l + 16] = d.x() * (q[l] >> 4) + d.y();
    }
}

//================================== k-quants

template<typename dst_t>
static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                  const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);
    const block_q2_K * x = (const block_q2_K *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int n   = tid/32;
    const int l   = tid - 32*n;
    const int is  = 8*n + l/16;

    const uint8_t q = x[i].qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

    float dall = x[i].dm[0];
    float dmin = x[i].dm[1];
    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
    const int is = tid/16;  // 0 or 1
    const int il = tid%16;  // 0...15
    const uint8_t q = x[i].qs[il] >> (2*is);
    dst_t * y = yy + i*QK_K + 16*is + il;

    float dall = x[i].dm[0];
    float dmin = x[i].dm[1];
    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif

}

template<typename dst_t>
static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                  const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);
    const block_q3_K * x = (const block_q3_K *) vx;

#if QK_K == 256
    const int r = item_ct1.get_local_id(2) / 4;
    const int tid = r/2;
    const int is0 = r%2;
    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
    const int n = tid / 4;
    const int j = tid - 4*n;

    uint8_t m = 1 << (4*n + j);
    int is = 8*n + 2*j + is0;
    int shift = 2*j;

    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
    float d_all = x[i].d;
    float dl = d_all * (us - 32);

    dst_t * y = yy + i*QK_K + 128*n + 32*j;
    const uint8_t * q = x[i].qs + 32*n;
    const uint8_t * hm = x[i].hmask;

    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
    const int tid = item_ct1.get_local_id(2);
    const int is = tid/16;  // 0 or 1
    const int il = tid%16;  // 0...15
    const int im = il/8;    // 0...1
    const int in = il%8;    // 0...7

    dst_t * y = yy + i*QK_K + 16*is + il;

    const uint8_t q = x[i].qs[il] >> (2*is);
    const uint8_t h = x[i].hmask[in] >> (2*is + im);
    const float   d = (float)x[i].d;

    if (is == 0) {
        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    } else {
        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    }
#endif

}

#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j < 4) {
        d = q[j] & 63; m = q[j + 4] & 63;
    } else {
        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
    }
}
#endif

template<typename dst_t>
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                  const sycl::nd_item<3> &item_ct1) {
    const block_q4_K * x = (const block_q4_K *) vx;

    const int i = item_ct1.get_group(2);

#if QK_K == 256
    // assume 32 threads
    const int tid = item_ct1.get_local_id(2);
    const int il  = tid/8;
    const int ir  = tid%8;
    const int is  = 2*il;
    const int n   = 4;

    dst_t * y = yy + i*QK_K + 64*il + n*ir;

    const float dall = x[i].dm[0];
    const float dmin = x[i].dm[1];

    const uint8_t * q = x[i].qs + 32*il + n*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;
    for (int l = 0; l < n; ++l) {
        y[l + 0] = d1 * (q[l] & 0xF) - m1;
        y[l +32] = d2 * (q[l] >>  4) - m2;
    }
#else
    const int tid = item_ct1.get_local_id(2);
    const uint8_t * q = x[i].qs;
    dst_t * y = yy + i*QK_K;
    const float d = (float)x[i].dm[0];
    const float m = (float)x[i].dm[1];
    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
#endif
}

template<typename dst_t>
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                  const sycl::nd_item<3> &item_ct1) {
    const block_q5_K * x = (const block_q5_K *) vx;

    const int i = item_ct1.get_group(2);

#if QK_K == 256
    // assume 64 threads - this is very slightly better than the one below
    const int tid = item_ct1.get_local_id(2);
    const int il  = tid/16;   // il is in 0...3
    const int ir  = tid%16;   // ir is in 0...15
    const int is  = 2*il;     // is is in 0...6

    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

    const float dall = x[i].dm[0];
    const float dmin = x[i].dm[1];

    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
    const uint8_t * qh = x[i].qh + 2*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;

    uint8_t hm = 1 << (2*il);
    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
    hm <<= 1;
    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
    const int tid = item_ct1.get_local_id(2);
    const uint8_t q = x[i].qs[tid];
    const int im = tid/8;   // 0...3
    const int in = tid%8;   // 0...7
    const int is = tid/16;  // 0 or 1
    const uint8_t h = x[i].qh[in] >> im;
    const float d = x[i].d;
    dst_t * y = yy + i*QK_K + tid;
    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}

template<typename dst_t>
static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                  const sycl::nd_item<3> &item_ct1) {
    const block_q6_K * x = (const block_q6_K *) vx;

    const int i = item_ct1.get_group(2);
#if QK_K == 256

    // assume 64 threads - this is very slightly better than the one below
    const int tid = item_ct1.get_local_id(2);
    const int ip  = tid/32;       // ip is 0 or 1
    const int il  = tid - 32*ip;  // 0...32
    const int is  = 8*ip + il/16;

    dst_t * y = yy + i*QK_K + 128*ip + il;

    const float d = x[i].d;

    const uint8_t * ql = x[i].ql + 64*ip + il;
    const uint8_t   qh = x[i].qh[32*ip + il];
    const int8_t  * sc = x[i].scales + is;

    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32);
#else

    // assume 32 threads
    const int tid = item_ct1.get_local_id(2);
    const int ip  = tid/16;       // 0 or 1
    const int il  = tid - 16*ip;  // 0...15

    dst_t * y = yy + i*QK_K + 16*ip + il;

    const float d = x[i].d;

    const uint8_t   ql = x[i].ql[16*ip + il];
    const uint8_t   qh = x[i].qh[il] >> (2*ip);
    const int8_t  * sc = x[i].scales;

    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
#endif
}

template<typename dst_t>
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                     const sycl::nd_item<3> &item_ct1,
                                     const uint64_t *iq2xxs_grid_ptr,
                                     const uint8_t *ksigns_iq2xs_ptr,
                                     const uint8_t *kmask_iq2xs_ptr) {

    const int i = item_ct1.get_group(2);
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
#else
    assert(false);
#endif

}

template<typename dst_t>
static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                    const sycl::nd_item<3> &item_ct1,
                                    const uint64_t *iq2xs_grid,
                                    const uint8_t *ksigns_iq2xs,
                                    const uint8_t *kmask_iq2xs) {

    const int i = item_ct1.get_group(2);
    const block_iq2_xs * x = (const block_iq2_xs *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    assert(false);
#endif

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                       const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
#pragma unroll
    for (int j = 0; j < 8; ++j)
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    assert(false);

#endif

}

template<typename dst_t>
static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                     const sycl::nd_item<3> &item_ct1,
                                     const uint32_t *iq3xxs_grid,
                                     const uint8_t *ksigns_iq2xs,
                                     const uint8_t *kmask_iq2xs) {

    const int i = item_ct1.get_group(2);
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t  * q3 = x[i].qs + 8*ib;
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    assert(false);
#endif

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                       const sycl::nd_item<3> &item_ct1,
                       const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {

    const int i = item_ct1.get_group(2);
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * qs = x[i].qs + 8*ib;
    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
    const uint8_t signs = x[i].signs[4*ib + il];
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    assert(false);
#endif

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                       const sycl::nd_item<3> &item_ct1,
                       const uint32_t *iq1s_grid_gpu) {

    const int i = item_ct1.get_group(2);
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
#pragma unroll
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    assert(false);
#endif

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
                       const sycl::nd_item<3> &item_ct1,
                       const uint32_t *iq1s_grid_gpu) {

    const int i = item_ct1.get_group(2);
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * sc = (const uint16_t *)x[i].scales;
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const int ib16 = 2*ib + il/2;  // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
#pragma unroll
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    assert(false);
#endif

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1) {

    const int i = item_ct1.get_group(2);
    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

    const int tid = item_ct1.get_local_id(2);
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[ib].qs + 4*il;
    const float d = (float)x[ib].d;
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }

}

template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1) {
    const int i = item_ct1.get_group(2);
    const block_iq4_xs * x = (const block_iq4_xs *)vx;

    const int tid = item_ct1.get_local_id(2);
    const int il = tid/8;  // 0...3
    const int ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}

#endif // GGML_SYCL_DEQUANTIZE_HPP
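For reference, the arithmetic applied by dequantize_q4_0 and dequantize_block_q4_0 above is: each byte of qs packs two 4-bit quants, and each quant is offset by 8 and scaled by the block scale d. Below is a minimal host-side sketch of that math, under the assumption of a simplified stand-in layout (a float scale instead of the real sycl::half; block_q4_0_demo is hypothetical, not the real struct).

#include <cstdint>
#include <cstdio>

struct block_q4_0_demo {
    float   d;        // block scale (fp16 in the real block_q4_0)
    uint8_t qs[16];   // 32 quants, two 4-bit values per byte
};

static void dequant_q4_0_demo(const block_q4_0_demo &b, float *out /* 32 floats */) {
    for (int j = 0; j < 16; ++j) {
        out[j +  0] = ((b.qs[j] & 0x0F) - 8.0f) * b.d;  // low nibble
        out[j + 16] = ((b.qs[j] >>   4) - 8.0f) * b.d;  // high nibble
    }
}

int main() {
    block_q4_0_demo b = { 0.5f, {} };
    b.qs[0] = 0x9F;   // low nibble 15 -> (15-8)*0.5 = 3.5; high nibble 9 -> (9-8)*0.5 = 0.5
    float y[32];
    dequant_q4_0_demo(b, y);
    printf("%f %f\n", y[0], y[16]);   // prints 3.500000 0.500000
    return 0;
}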
1022 ggml-sycl/dmmv.cpp Normal file
File diff suppressed because it is too large
27 ggml-sycl/dmmv.hpp Normal file
@@ -0,0 +1,27 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_DMMV_HPP
#define GGML_SYCL_DMMV_HPP

#include "common.hpp"


void ggml_sycl_op_dequantize_mul_mat_vec(
    ggml_backend_sycl_context & ctx,
    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
    const int64_t src1_ncols, const int64_t src1_padded_row_size,
    const dpct::queue_ptr &stream);

#endif // GGML_SYCL_DMMV_HPP
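The "dmmv" in the header above stands for dequantize-mul-mat-vec: y = A*x where the rows of A stay quantized in memory and each element is dequantized on the fly inside the dot product. A scalar, self-contained sketch of that idea follows (not the SYCL implementation; blk is a hypothetical stand-in mirroring the Q8_0 layout of a per-block scale plus int8 quants).

#include <cstdio>

struct blk { float d; signed char qs[32]; };  // stand-in for block_q8_0

// fused dequantize + dot product over one quantized row
static float dot_row(const blk *row, int nblocks, const float *x) {
    float acc = 0.0f;
    for (int b = 0; b < nblocks; ++b) {
        for (int j = 0; j < 32; ++j) {
            acc += row[b].d * row[b].qs[j] * x[32*b + j];  // dequantize, then accumulate
        }
    }
    return acc;
}

int main() {
    blk row[1] = {{0.5f, {}}};
    row[0].qs[0] = 4;                     // dequantizes to 0.5 * 4 = 2.0
    float x[32] = {3.0f};                 // x[0] = 3, rest zero
    printf("%f\n", dot_row(row, 1, x));   // prints 6.000000
    return 0;
}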
@@ -589,94 +589,75 @@ namespace dpct
     }

     /// dpct device extension
-    class device_ext : public sycl::device
-    {
+    class device_ext : public sycl::device {
         typedef std::mutex mutex_type;

     public:
-        device_ext() : sycl::device(), _ctx(*this) {}
-        ~device_ext()
-        {
+        device_ext() : sycl::device() {}
+        ~device_ext() {
             std::lock_guard<mutex_type> lock(m_mutex);
             clear_queues();
         }
-        device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this)
-        {
+        device_ext(const sycl::device &base) : sycl::device(base) {
            std::lock_guard<mutex_type> lock(m_mutex);
            init_queues();
         }

         int is_native_atomic_supported() { return 0; }
-        int get_major_version() const
-        {
-            return dpct::get_major_version(*this);
-        }
+        int get_major_version() const { return dpct::get_major_version(*this); }

-        int get_minor_version() const
-        {
-            return dpct::get_minor_version(*this);
-        }
+        int get_minor_version() const { return dpct::get_minor_version(*this); }

-        int get_max_compute_units() const
-        {
+        int get_max_compute_units() const {
             return get_device_info().get_max_compute_units();
         }

         /// Return the maximum clock frequency of this device in KHz.
-        int get_max_clock_frequency() const
-        {
+        int get_max_clock_frequency() const {
             return get_device_info().get_max_clock_frequency();
         }

         int get_integrated() const { return get_device_info().get_integrated(); }

-        int get_max_sub_group_size() const
-        {
+        int get_max_sub_group_size() const {
             return get_device_info().get_max_sub_group_size();
         }

-        int get_max_register_size_per_work_group() const
-        {
+        int get_max_register_size_per_work_group() const {
             return get_device_info().get_max_register_size_per_work_group();
         }

-        int get_max_work_group_size() const
-        {
+        int get_max_work_group_size() const {
             return get_device_info().get_max_work_group_size();
         }

-        int get_mem_base_addr_align() const
-        {
+        int get_mem_base_addr_align() const {
             return get_info<sycl::info::device::mem_base_addr_align>();
         }

-        size_t get_global_mem_size() const
-        {
+        size_t get_global_mem_size() const {
             return get_device_info().get_global_mem_size();
         }

-        size_t get_max_mem_alloc_size() const
-        {
+        size_t get_max_mem_alloc_size() const {
             return get_device_info().get_max_mem_alloc_size();
         }

         /// Get the number of bytes of free and total memory on the SYCL device.
-        /// \param [out] free_memory The number of bytes of free memory on the SYCL device.
-        /// \param [out] total_memory The number of bytes of total memory on the SYCL device.
-        void get_memory_info(size_t &free_memory, size_t &total_memory)
-        {
+        /// \param [out] free_memory The number of bytes of free memory on the
+        /// SYCL device. \param [out] total_memory The number of bytes of total
+        /// memory on the SYCL device.
+        void get_memory_info(size_t &free_memory, size_t &total_memory) {
             total_memory = get_device_info().get_global_mem_size();
-            const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not "
+            const char *warning_info =
+                "get_memory_info: [warning] ext_intel_free_memory is not "
                 "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
                 "use total memory as free memory";
 #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
-            if (!has(sycl::aspect::ext_intel_free_memory))
-            {
+            if (!has(sycl::aspect::ext_intel_free_memory)) {
                 std::cerr << warning_info << std::endl;
                 free_memory = total_memory;
-            }
-            else
-            {
+            } else {
                 free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
             }
 #else
@@ -690,164 +671,139 @@ namespace dpct
 #endif
         }

-        void get_device_info(device_info &out) const
-        {
+        void get_device_info(device_info &out) const {
             dpct::get_device_info(out, *this);
         }

-        device_info get_device_info() const
-        {
+        device_info get_device_info() const {
             device_info prop;
             dpct::get_device_info(prop, *this);
             return prop;
         }

-        void reset()
-        {
+        void reset() {
             std::lock_guard<mutex_type> lock(m_mutex);
             clear_queues();
             init_queues();
         }

-        sycl::queue &in_order_queue() { return *_q_in_order; }
+        sycl::queue &in_order_queue() { return _q_in_order; }

-        sycl::queue &out_of_order_queue() { return *_q_out_of_order; }
+        sycl::queue &out_of_order_queue() { return _q_out_of_order; }

-        sycl::queue &default_queue()
-        {
-            return in_order_queue();
-        }
+        sycl::queue &default_queue() { return in_order_queue(); }

-        void queues_wait_and_throw()
-        {
+        void queues_wait_and_throw() {
             std::unique_lock<mutex_type> lock(m_mutex);
-            std::vector<std::shared_ptr<sycl::queue>> current_queues(
-                _queues);
             lock.unlock();
-            for (const auto &q : current_queues)
-            {
-                q->wait_and_throw();
+            for (auto &q : _queues) {
+                q.wait_and_throw();
             }
-            // Guard the destruct of current_queues to make sure the ref count is safe.
+            // Guard the destruct of current_queues to make sure the ref count is
+            // safe.
             lock.lock();
         }

-        sycl::queue *create_queue(bool enable_exception_handler = false)
-        {
+        sycl::queue create_queue(bool enable_exception_handler = false) {
             return create_in_order_queue(enable_exception_handler);
         }

-        sycl::queue *create_queue(sycl::context context, sycl::device device,
+        sycl::queue create_queue(sycl::device device,
                                   bool enable_exception_handler = false) {
-            return create_in_order_queue(context, device, enable_exception_handler);
+            return create_in_order_queue(device, enable_exception_handler);
         }

-        sycl::queue *create_in_order_queue(bool enable_exception_handler = false) {
+        sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
             std::lock_guard<mutex_type> lock(m_mutex);
             return create_queue_impl(enable_exception_handler,
                                      sycl::property::queue::in_order());
         }

-        sycl::queue *create_in_order_queue(sycl::context context, sycl::device device,
+        sycl::queue create_in_order_queue(sycl::device device,
                                           bool enable_exception_handler = false) {
             std::lock_guard<mutex_type> lock(m_mutex);
-            return create_queue_impl(context, device, enable_exception_handler,
+            return create_queue_impl(device, enable_exception_handler,
                                      sycl::property::queue::in_order());
         }

-        sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) {
+        sycl::queue create_out_of_order_queue(
+            bool enable_exception_handler = false) {
             std::lock_guard<mutex_type> lock(m_mutex);
             return create_queue_impl(enable_exception_handler);
         }

-        void destroy_queue(sycl::queue *&queue)
-        {
+        void destroy_queue(sycl::queue queue) {
             std::lock_guard<mutex_type> lock(m_mutex);
-            _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
-                                         [=](const std::shared_ptr<sycl::queue> &q) -> bool
-                                         {
-                                             return q.get() == queue;
-                                         }),
-                          _queues.end());
-            queue = nullptr;
+            _queues.clear();
         }
-        void set_saved_queue(sycl::queue *q)
-        {
+        void set_saved_queue(sycl::queue q) {
             std::lock_guard<mutex_type> lock(m_mutex);
             _saved_queue = q;
         }
-        sycl::queue *get_saved_queue() const
-        {
+        sycl::queue get_saved_queue() const {
             std::lock_guard<mutex_type> lock(m_mutex);
             return _saved_queue;
         }
-        sycl::context get_context() const { return _ctx; }

     private:
-        void clear_queues()
-        {
-            _queues.clear();
-            _q_in_order = _q_out_of_order = _saved_queue = nullptr;
-        }
+        void clear_queues() { _queues.clear(); }

-        void init_queues()
-        {
-            _q_in_order = create_queue_impl(true, sycl::property::queue::in_order());
+        void init_queues() {
+            _q_in_order =
+                create_queue_impl(true, sycl::property::queue::in_order());
             _q_out_of_order = create_queue_impl(true);
-            _saved_queue = &default_queue();
+            _saved_queue = default_queue();
         }

-        /// Caller should acquire resource \p m_mutex before calling this function.
+        /// Caller should acquire resource \p m_mutex before calling this
+        /// function.
         template <class... Properties>
-        sycl::queue *create_queue_impl(bool enable_exception_handler,
-                                       Properties... properties)
-        {
+        sycl::queue create_queue_impl(bool enable_exception_handler,
                                      Properties... properties) {
             sycl::async_handler eh = {};
-            if (enable_exception_handler)
-            {
+            if (enable_exception_handler) {
                 eh = exception_handler;
             }
-            _queues.push_back(std::make_shared<sycl::queue>(
-                _ctx, *this, eh,
+            auto q = sycl::queue(*this, eh,
                                  sycl::property_list(
 #ifdef DPCT_PROFILING_ENABLED
                                      sycl::property::queue::enable_profiling(),
 #endif
-                                     properties...)));
+                                     properties...));
+            _queues.push_back(q);

-            return _queues.back().get();
+            return _queues.back();
         }

         template <class... Properties>
-        sycl::queue *create_queue_impl(sycl::context context, sycl::device device,
+        sycl::queue create_queue_impl(sycl::device device,
                                       bool enable_exception_handler,
                                       Properties... properties) {
|
||||||
sycl::async_handler eh = {};
|
sycl::async_handler eh = {};
|
||||||
if (enable_exception_handler) {
|
if (enable_exception_handler) {
|
||||||
eh = exception_handler;
|
eh = exception_handler;
|
||||||
}
|
}
|
||||||
_queues.push_back(std::make_shared<sycl::queue>(
|
_queues.push_back(
|
||||||
context, device, eh,
|
sycl::queue(device, eh,
|
||||||
sycl::property_list(
|
sycl::property_list(
|
||||||
#ifdef DPCT_PROFILING_ENABLED
|
#ifdef DPCT_PROFILING_ENABLED
|
||||||
sycl::property::queue::enable_profiling(),
|
sycl::property::queue::enable_profiling(),
|
||||||
#endif
|
#endif
|
||||||
properties...)));
|
properties...)));
|
||||||
|
|
||||||
return _queues.back().get();
|
return _queues.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_version(int &major, int &minor) const
|
void get_version(int &major, int &minor) const {
|
||||||
{
|
|
||||||
detail::get_version(*this, major, minor);
|
detail::get_version(*this, major, minor);
|
||||||
}
|
}
|
||||||
sycl::queue *_q_in_order, *_q_out_of_order;
|
sycl::queue _q_in_order, _q_out_of_order;
|
||||||
sycl::queue *_saved_queue;
|
sycl::queue _saved_queue;
|
||||||
sycl::context _ctx;
|
std::vector<sycl::queue> _queues;
|
||||||
std::vector<std::shared_ptr<sycl::queue>> _queues;
|
|
||||||
mutable mutex_type m_mutex;
|
mutable mutex_type m_mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/// device manager
|
/// device manager
|
||||||
class dev_mgr
|
class dev_mgr
|
||||||
{
|
{
|
||||||
|
|
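The switch above from std::shared_ptr<sycl::queue> to plain sycl::queue values works because SYCL 2020 queues are themselves reference-counted handles: copies share the underlying queue and keep it alive. A minimal standalone sketch of the pattern, assuming a SYCL 2020 compiler such as DPC++ (queue_pool is an illustrative name, not part of this patch):

#include <mutex>
#include <vector>
#include <sycl/sycl.hpp>

class queue_pool {
  public:
    // sycl::queue copies share the underlying queue, so storing and
    // returning values is cheap and needs no shared_ptr bookkeeping.
    sycl::queue create_in_order_queue(const sycl::device & dev) {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_queues.push_back(sycl::queue(dev, sycl::property::queue::in_order()));
        return m_queues.back();
    }

  private:
    std::vector<sycl::queue> m_queues;
    mutable std::mutex       m_mutex;
};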
ggml-sycl/mmq.cpp (new file, 3031 lines) — file diff suppressed because it is too large.

ggml-sycl/mmq.hpp (new file, 33 lines):
@@ -0,0 +1,33 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMQ_HPP
+#define GGML_SYCL_MMQ_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_mul_mat_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor* src0,
+    const ggml_tensor* src1,
+    ggml_tensor* dst,
+    const char* src0_dd_i,
+    const float* src1_ddf_i,
+    const char* src1_ddq_i,
+    float* dst_dd_i,
+    const int64_t row_low,
+    const int64_t row_high,
+    const int64_t src1_ncols,
+    const int64_t src1_padded_row_size,
+    const dpct::queue_ptr& stream);
+
+#endif // GGML_SYCL_MMQ_HPP
ggml-sycl/mmvq.cpp (new file, 1024 lines) — file diff suppressed because it is too large.

ggml-sycl/mmvq.hpp (new file, 27 lines):
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMVQ_HPP
+#define GGML_SYCL_MMVQ_HPP
+
+#include "common.hpp"
+
+
+void ggml_sycl_op_mul_mat_vec_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream);
+
+#endif // GGML_SYCL_MMVQ_HPP
@@ -18,8 +18,6 @@
 #define GGML_SYCL_MAX_DEVICES 48
 #define GGML_SYCL_NAME "SYCL"

-// FIXME: 1024 from cuda
-#define GROUP_SIZE 1024
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses


ggml-sycl/vecdotq.hpp (new file, 1161 lines) — file diff suppressed because it is too large.
@@ -1745,31 +1745,37 @@ void ggml_vk_instance_init() {

         // Default to using all dedicated GPUs
         for (size_t i = 0; i < devices.size(); i++) {
-            vk::PhysicalDeviceProperties props = devices[i].getProperties();
+            vk::PhysicalDeviceProperties2 new_props;
+            vk::PhysicalDeviceDriverProperties new_driver;
+            vk::PhysicalDeviceIDProperties new_id;
+            new_props.pNext = &new_driver;
+            new_driver.pNext = &new_id;
+            devices[i].getProperties2(&new_props);

-            if (props.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
+            if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
                 // Check if there are two physical devices corresponding to the same GPU
                 auto old_device = std::find_if(
                     vk_instance.device_indices.begin(),
                     vk_instance.device_indices.end(),
-                    [&devices, &props](const size_t k){ return devices[k].getProperties().deviceID == props.deviceID; }
+                    [&devices, &new_id](const size_t k){
+                        vk::PhysicalDeviceProperties2 old_props;
+                        vk::PhysicalDeviceIDProperties old_id;
+                        old_props.pNext = &old_id;
+                        devices[k].getProperties2(&old_props);
+                        return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
+                    }
                 );
                 if (old_device == vk_instance.device_indices.end()) {
                     vk_instance.device_indices.push_back(i);
                 } else {
                     // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
                     // This can cause error when splitting layers aross the devices, need to keep only 1
-                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same device id");
+                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");

-                    vk::PhysicalDeviceProperties2 old_prop;
+                    vk::PhysicalDeviceProperties2 old_props;
                     vk::PhysicalDeviceDriverProperties old_driver;
-                    old_prop.pNext = &old_driver;
-                    devices[*old_device].getProperties2(&old_prop);
+                    old_props.pNext = &old_driver;
+                    devices[*old_device].getProperties2(&old_props);

-                    vk::PhysicalDeviceProperties2 new_prop;
-                    vk::PhysicalDeviceDriverProperties new_driver;
-                    new_prop.pNext = &new_driver;
-                    devices[i].getProperties2(&new_prop);

                     std::map<vk::DriverId, int> driver_priorities {};
                     int old_priority = std::numeric_limits<int>::max();
@@ -1777,7 +1783,7 @@ void ggml_vk_instance_init() {

                     // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id
                     // Smaller number -> higher priority
-                    switch (old_prop.properties.vendorID) {
+                    switch (old_props.properties.vendorID) {
                         case VK_VENDOR_ID_AMD:
                             driver_priorities[vk::DriverId::eMesaRadv] = 1;
                             driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
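The deduplication change above keys on deviceUUID, queried through a pNext chain, instead of deviceID, because the UUID identifies the physical GPU even when two drivers expose it as two devices. A minimal sketch of that query pattern, assuming the Vulkan-Hpp headers (same_gpu is an illustrative helper, not part of this patch):

#include <algorithm>
#include <iterator>
#include <vulkan/vulkan.hpp>

bool same_gpu(vk::PhysicalDevice a, vk::PhysicalDevice b) {
    vk::PhysicalDeviceProperties2 props_a, props_b;
    vk::PhysicalDeviceIDProperties id_a, id_b;
    props_a.pNext = &id_a; // chain the ID struct so the driver fills it in
    props_b.pNext = &id_b;
    a.getProperties2(&props_a);
    b.getProperties2(&props_b);
    // deviceUUID identifies the physical GPU across drivers, unlike deviceID
    return std::equal(std::begin(id_a.deviceUUID), std::end(id_a.deviceUUID),
                      std::begin(id_b.deviceUUID));
}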
ggml.c (205 changed lines):

@@ -1761,9 +1761,8 @@ struct ggml_compute_state_shared {
     int n_threads;

     // synchronization primitives
-    atomic_int n_active;  // num active threads
-    atomic_int node_n;    // active graph node
-    atomic_int node_task; // active graph node task phase
+    atomic_int n_barrier;
+    atomic_int n_barrier_passed;

     ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
     void* abort_callback_data;
@@ -19027,47 +19026,49 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
     return n_tasks;
 }

-static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_compute_state * state, const bool do_yield) {
-    // wait for other threads to finish
-    const int last_node_n = * node_n;
-
-    while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        *node_n = atomic_load(&state->shared->node_n);
-        if (*node_n != last_node_n) {
-            break;
-        }
-
-#if defined(__SSE3__)
-        // Tell the processor we're spinning. It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
-    }
-}
-
-static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_compute_state * state, const bool do_yield) {
-    // wait for other threads to finish
-    const int last_task_phase = *task_phase;
-
-    while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        *task_phase = atomic_load(&state->shared->node_task);
-        if (*task_phase != last_task_phase) {
-            break;
-        }
-
-#if defined(__SSE3__)
-        // Tell the processor we're spinning. It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
-    }
-}
+#ifdef GGML_USE_OPENMP
+static void ggml_barrier(struct ggml_compute_state * state) {
+    if (state->shared->n_threads == 1) {
+        return;
+    }
+
+    #pragma omp barrier
+}
+#else
+static void ggml_barrier(struct ggml_compute_state * state) {
+    if (state->shared->n_threads == 1) {
+        return;
+    }
+
+    atomic_int * n_barrier        = &state->shared->n_barrier;
+    atomic_int * n_barrier_passed = &state->shared->n_barrier_passed;
+
+    int n_threads  = state->shared->n_threads;
+    int passed_old = atomic_load(n_barrier_passed);
+
+    if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+        // last thread
+        atomic_store(n_barrier, 0);
+        atomic_fetch_add(n_barrier_passed, 1);
+    } else {
+        // wait for other threads
+        //while (atomic_load(n_barrier_passed) == passed_old) {
+        //}
+        const int n_spin_before_sleep = 100000;
+        while (true) {
+            for (int i = 0; i < n_spin_before_sleep; i++) {
+                if (atomic_load(n_barrier_passed) != passed_old) {
+                    return;
+                }
+            #if defined(__SSE3__)
+                _mm_pause();
+            #endif
+            }
+            sched_yield();
+        }
+    }
+}
+#endif

 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
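The non-OpenMP branch above is a generation-counter barrier: the last thread to arrive resets n_barrier and bumps n_barrier_passed; everyone else spins until the generation advances, yielding periodically. A standalone sketch of the same logic with std::atomic so it can be compiled outside ggml (barrier_state and barrier_wait are illustrative names; the spin-before-sleep tuning is omitted):

#include <atomic>
#include <thread>

struct barrier_state {
    std::atomic<int> n_barrier{0};
    std::atomic<int> n_barrier_passed{0};
    int n_threads;
};

void barrier_wait(barrier_state & s) {
    const int passed_old = s.n_barrier_passed.load();
    if (s.n_barrier.fetch_add(1) == s.n_threads - 1) {
        // last thread to arrive: reset the counter and release the others
        s.n_barrier.store(0);
        s.n_barrier_passed.fetch_add(1);
    } else {
        // wait until the generation counter advances
        while (s.n_barrier_passed.load() == passed_old) {
            std::this_thread::yield();
        }
    }
}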
@@ -19075,136 +19076,54 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     const struct ggml_cgraph * cgraph = state->shared->cgraph;
     const struct ggml_cplan  * cplan  = state->shared->cplan;

+    const int ith       = state->ith;
     const int n_threads = state->shared->n_threads;

-    set_numa_thread_affinity(state->ith);
+    set_numa_thread_affinity(ith);

-    int node_n     = -1;
-    int task_phase = GGML_TASK_TYPE_FINALIZE;
+    struct ggml_compute_params params = {
+        /*.type  =*/ GGML_TASK_TYPE_INIT,
+        /*.ith   =*/ ith,
+        /*.nth   =*/ state->shared->n_threads,
+        /*.wsize =*/ cplan->work_size,
+        /*.wdata =*/ cplan->work_data,
+    };

-    while (true) {
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
         if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->shared->node_n += 1;
             state->ec = GGML_STATUS_ABORTED;
             return 0;
         }

-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            // all other threads are finished and spinning
-            // do finalize and init here so we don't have synchronize again
-            struct ggml_compute_params params = {
-                /*.type  =*/ GGML_TASK_TYPE_FINALIZE,
-                /*.ith   =*/ 0,
-                /*.nth   =*/ 0,
-                /*.wsize =*/ cplan->work_size,
-                /*.wdata =*/ cplan->work_data,
-            };
-
-            if (node_n != -1) {
-                /* FINALIZE */
-                struct ggml_tensor * node = cgraph->nodes[node_n];
-                if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                    ggml_compute_forward(&params, node, state);
-                }
-                ggml_graph_compute_perf_stats_node(node, state->shared);
-            }
-
-            // distribute new work or execute it direct if 1T
-            while (++node_n < cgraph->n_nodes) {
-                GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
-                struct ggml_tensor * node = cgraph->nodes[node_n];
-                const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-
-                state->shared->perf_node_start_cycles  = ggml_perf_cycles();
-                state->shared->perf_node_start_time_us = ggml_perf_time_us();
-
-                params.nth = n_tasks;
-
-                if (n_tasks == 1) {
-                    /* INIT */
-                    if (GGML_OP_HAS_INIT[node->op]) {
-                        params.type = GGML_TASK_TYPE_INIT;
-                        ggml_compute_forward(&params, node, state);
-                    }
-
-                    // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
-                    // they do something more efficient than spinning (?)
-                    params.type = GGML_TASK_TYPE_COMPUTE;
-                    ggml_compute_forward(&params, node, state);
-
-                    if (GGML_OP_HAS_FINALIZE[node->op]) {
-                        params.type = GGML_TASK_TYPE_FINALIZE;
-                        ggml_compute_forward(&params, node, state);
-                    }
-
-                    ggml_graph_compute_perf_stats_node(node, state->shared);
-                } else {
-                    break;
-                }
-
-                if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-                    break;
-                }
-            }
-
-            task_phase = GGML_TASK_TYPE_INIT;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_n,    node_n);
-            atomic_store(&state->shared->node_task, task_phase);
-        } else {
-            ggml_graph_compute_thread_sync_node(&node_n, state, false);
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
-        }
-
-        // check if we should stop
-        if (node_n >= cgraph->n_nodes) break;
-
-        /* INIT & COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
         const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);

-        struct ggml_compute_params params = {
-            /*.type  =*/ GGML_TASK_TYPE_INIT,
-            /*.ith   =*/ state->ith,
-            /*.nth   =*/ n_tasks,
-            /*.wsize =*/ cplan->work_size,
-            /*.wdata =*/ cplan->work_data,
-        };
-
-        if (state->ith < n_tasks) {
-            if (GGML_OP_HAS_INIT[node->op]) {
-                ggml_compute_forward(&params, node, state);
-            }
-        }
-
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            task_phase = GGML_TASK_TYPE_COMPUTE;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
-        }
-        else {
-            // TODO: this sched_yield can have significant impact on the performance - either positive or negative
-            // depending on the workload and the operating system.
-            // since it is not clear what is the best approach, it should potentially become user-configurable
-            // ref: https://github.com/ggerganov/ggml/issues/291
-            // UPD: adding the do_yield flag seems to resolve the issue universally
-            const bool do_yield = node_n < 0 || cgraph->nodes[node_n]->op == GGML_OP_MUL_MAT;
-            ggml_graph_compute_thread_sync_task(&task_phase, state, do_yield);
-        }
-
-        if (state->ith < n_tasks) {
-            params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node, state);
-        }
-
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            task_phase = GGML_TASK_TYPE_FINALIZE;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
-        }
-        else {
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
-        }
+        params.nth = n_tasks;
+
+        /* INIT */
+        if (GGML_OP_HAS_INIT[node->op]) {
+            if (ith < n_tasks) {
+                params.type = GGML_TASK_TYPE_INIT;
+                ggml_compute_forward(&params, node, state);
+            }
+            ggml_barrier(state);
+        }
+
+        /* COMPUTE */
+        if (ith < n_tasks) {
+            params.type = GGML_TASK_TYPE_COMPUTE;
+            ggml_compute_forward(&params, node, state);
+        }
+
+        ggml_barrier(state);
+
+        /* FINALIZE */
+        if (GGML_OP_HAS_FINALIZE[node->op]) {
+            if (params.ith == 0) {
+                params.type = GGML_TASK_TYPE_FINALIZE;
+                ggml_compute_forward(&params, node, state);
+            }
+            ggml_barrier(state);
+        }
     }
 }
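The rewrite above replaces the coordinator-thread/spin-phase state machine with a simple lockstep walk: every thread visits each node and the phases are separated by full barriers. A sketch of that per-node schedule using C++20 std::barrier in place of ggml_barrier (the callbacks are placeholders, and the conditional INIT/FINALIZE barriers are simplified):

#include <barrier>
#include <functional>

void run_node(std::barrier<> & sync, int ith, int n_tasks,
              const std::function<void(int)> & init,
              const std::function<void(int)> & compute,
              const std::function<void()>    & finalize) {
    if (ith < n_tasks) init(ith);
    sync.arrive_and_wait();          // all INIT work done before any COMPUTE
    if (ith < n_tasks) compute(ith);
    sync.arrive_and_wait();          // all COMPUTE done before FINALIZE
    if (ith == 0) finalize();
    sync.arrive_and_wait();          // node fully finished before the next one
}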
@@ -19396,7 +19315,6 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
             // update the number of threads from the actual number of threads that we got from OpenMP
             n_threads = omp_get_num_threads();
             workers[0].shared->n_threads = n_threads;
-            workers[0].shared->n_active  = n_threads;
         }
         ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
     }
@@ -19459,9 +19377,8 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         /*.perf_node_start_cycles  =*/ 0,
         /*.perf_node_start_time_us =*/ 0,
         /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-        /*.node_task               =*/ GGML_TASK_TYPE_FINALIZE,
+        /*.n_barrier               =*/ 0,
+        /*.n_barrier_passed        =*/ 0,
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
         /*.current_chunk;          =*/ 0,
ggml.h (6 changed lines):

@@ -319,6 +319,12 @@
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
     GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

+#define GGML_TENSOR_BINARY_OP_LOCALS01 \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
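For orientation, GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) declares local copies ne00..ne03 of src0->ne[0..3], so the new LOCALS01 macro covers both sources but not dst. A toy demo of the same token-pasting pattern, simplified to two dimensions (the real macro lives in ggml.h):

#include <cstdint>
#include <cstdio>

struct tensor { int64_t ne[2]; };

#define TENSOR_LOCALS(t, prefix) \
    const int64_t prefix##0 = (t)->ne[0]; \
    const int64_t prefix##1 = (t)->ne[1];

int main() {
    tensor src0 = {{4, 8}};
    tensor src1 = {{8, 2}};
    TENSOR_LOCALS(&src0, ne0) // declares ne00, ne01
    TENSOR_LOCALS(&src1, ne1) // declares ne10, ne11
    std::printf("%lld %lld %lld %lld\n",
                (long long) ne00, (long long) ne01,
                (long long) ne10, (long long) ne11);
    return 0;
}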
@@ -15744,6 +15744,7 @@ Current version: 147
             <option value="claude-3-opus-20240229">claude-3-opus</option>
             <option value="claude-3-sonnet-20240229">claude-3-sonnet</option>
             <option value="claude-3-haiku-20240307">claude-3-haiku</option>
+            <option value="claude-3-5-sonnet-20240620">claude-3.5-sonnet</option>
             </select>
             <input type="checkbox" id="claudeaddversion" onchange="" checked>
             <div class="box-label" title="Add endpoint version">Add Endpoint Version</div>
llama.cpp (389 changed lines):

@@ -2321,6 +2321,8 @@ struct llama_vocab {
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

+    int max_token_len = 0; // used for optimizing longest token search
+
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;

@@ -2338,21 +2340,23 @@ struct llama_vocab {
     id special_cls_id  = -1;
     id special_mask_id = -1;

-    int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
-    int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
-
     id linefeed_id       = 13;
     id special_prefix_id = -1;
     id special_suffix_id = -1;
     id special_middle_id = -1;
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token

-    bool add_space_prefix = true;
+    // tokenizer flags
+    bool tokenizer_add_space_prefix = true;
+    bool tokenizer_add_bos          = false;
+    bool tokenizer_add_eos          = false;
+    bool tokenizer_ignore_merges    = false;

     int find_bpe_rank(std::string token_left, std::string token_right) const {
-        // GGML_ASSERT(token_left.find(" ") == std::string::npos);
-        // GGML_ASSERT(token_left.find("\n") == std::string::npos);
-        // GGML_ASSERT(token_right.find(" ") == std::string::npos);
-        // GGML_ASSERT(token_right.find("\n") == std::string::npos);
+        //GGML_ASSERT(token_left.find(' ') == std::string::npos);
+        //GGML_ASSERT(token_left.find('\n') == std::string::npos);
+        //GGML_ASSERT(token_right.find(' ') == std::string::npos);
+        //GGML_ASSERT(token_right.find('\n') == std::string::npos);
         //the above breaks gguf v1 falcons
         replace_all(token_left,  " ",  "\u0120");
         replace_all(token_left,  "\n", "\u010A");
@@ -4823,7 +4827,7 @@ static void llm_load_vocab(

             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
-                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             } // The default value of add_space_prefix is true.
         } else if (tokenizer_model == "bert") {
             vocab.type = LLAMA_VOCAB_TYPE_WPM;
@@ -4836,13 +4840,13 @@ static void llm_load_vocab(
             vocab.special_pad_id  = 0;
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
-            vocab.add_space_prefix = false;
+            vocab.tokenizer_add_space_prefix = false;
         } else if (tokenizer_model == "gpt2") {
             vocab.type = LLAMA_VOCAB_TYPE_BPE;

             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
-                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             }

             // read bpe merges and populate bpe ranks
@@ -4909,6 +4913,8 @@ static void llm_load_vocab(
                 tokenizer_pre == "llama-v3" ||
                 tokenizer_pre == "llama-bpe") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                vocab.tokenizer_ignore_merges = true;
+                vocab.tokenizer_add_bos = true;
             } else if (
                 tokenizer_pre == "deepseek-llm") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
@@ -4959,6 +4965,14 @@ static void llm_load_vocab(
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_bos = true;
+            vocab.tokenizer_add_eos = false;
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_bos = true;
+            vocab.tokenizer_add_eos = false;
         } else {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
@@ -4999,6 +5013,7 @@ static void llm_load_vocab(
             }

             vocab.token_to_id[word] = i;
+            vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());

             auto & token_data = vocab.id_to_token[i];
             token_data.text = std::move(word);
@@ -5112,10 +5127,10 @@ static void llm_load_vocab(
             bool temp = true;

             if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.special_add_bos = int(temp);
+                vocab.tokenizer_add_bos = temp;
             }
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.special_add_eos = int(temp);
+                vocab.tokenizer_add_eos = temp;
             }
         }
@@ -5215,7 +5230,7 @@ static void llm_load_vocab(
         );

         // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
             _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : vocab.cache_special_tokens) {
@@ -5309,6 +5324,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }

+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q           = %d\n", __func__, hparams.n_lora_q);
@@ -7716,6 +7733,50 @@ struct llm_build_context {
         return lctx.inp_s_seq;
     }

+    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
+        }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+        struct ggml_tensor * cur;
+
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    struct ggml_tensor * inp_mean = build_inp_mean();
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    cur = ggml_get_rows(ctx0, inp, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
+            default:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+
+        cb(cur, "result_embd_pooled", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
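The MEAN branch above pools by matrix multiply: inp_mean carries one row of averaging weights per sequence, so transposing the embeddings and multiplying yields one pooled vector per sequence. The same computation unrolled into plain loops, as an illustrative sketch (the row-major layouts in the comments are assumptions, not the ggml API):

#include <vector>

// embd: n_tokens x n_embd (row-major), mean: n_seqs x n_tokens weights
std::vector<float> mean_pool(const std::vector<float> & embd,
                             const std::vector<float> & mean,
                             int n_tokens, int n_embd, int n_seqs) {
    std::vector<float> out(n_seqs * n_embd, 0.0f);
    for (int s = 0; s < n_seqs; ++s) {
        for (int t = 0; t < n_tokens; ++t) {
            for (int e = 0; e < n_embd; ++e) {
                // weighted sum over the sequence's tokens = mean embedding
                out[s*n_embd + e] += mean[s*n_tokens + t] * embd[t*n_embd + e];
            }
        }
    }
    return out;
}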
@@ -8696,8 +8757,6 @@ struct llm_build_context {
         if (model.arch != LLM_ARCH_JINA_BERT_V2) {
             inp_pos = build_inp_pos();
         }
-        struct ggml_tensor * inp_mean = build_inp_mean();
-        struct ggml_tensor * inp_cls  = build_inp_cls();

         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8872,28 +8931,6 @@ struct llm_build_context {
         cur = inpL;
         cb(cur, "result_embd", -1);

-        // pooling layer
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // nop
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-                {
-                    cur = ggml_get_rows(ctx0, cur, inp_cls);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    GGML_ASSERT(false && "Invalid pooling type");
-                } break;
-        }
-
         ggml_build_forward_expand(gf, cur);

         return gf;
@@ -11978,6 +12015,11 @@ static struct ggml_cgraph * llama_build_graph(
             GGML_ASSERT(false);
     }

+    // add on pooling layer
+    if (lctx.cparams.embeddings) {
+        result = llm.append_pooling(result);
+    }
+
     llm.free();

     return result;
@@ -12067,7 +12109,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // (!a || b) is a logical implication (a -> b)
             // !hparams.causal_attn -> !cparams.causal_attn
             (hparams.causal_attn || !cparams.causal_attn) &&
-            "causal attention with embedding models is not supported"
+            "causal attention is not supported by this model"
         );

     if (lctx.inp_KQ_mask) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
||||||
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
|
GGML_ASSERT(lctx.inp_cls);
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||||
|
|
||||||
|
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
||||||
|
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
||||||
|
|
||||||
|
std::vector<int> last_pos(n_tokens, -1);
|
||||||
|
std::vector<int> last_row(n_tokens, -1);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
|
const llama_pos pos = batch.pos[i];
|
||||||
|
|
||||||
|
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
|
||||||
|
|
||||||
|
if (pos >= last_pos[seq_id]) {
|
||||||
|
last_pos[seq_id] = pos;
|
||||||
|
last_row[seq_id] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
if (last_row[i] >= 0) {
|
||||||
|
data[i] = last_row[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (kv_self.recurrent) {
|
if (kv_self.recurrent) {
|
||||||
const int64_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self.n;
|
||||||
|
|
||||||
|
@ -12260,8 +12333,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
|
|
||||||
// TODO: use a per-batch flag for logits presence instead
|
// TODO: use a per-batch flag for logits presence instead
|
||||||
const bool has_logits = cparams.causal_attn;
|
const bool has_logits = !cparams.embeddings;
|
||||||
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
||||||
|
|
||||||
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||||
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
||||||
|
@ -12391,11 +12464,13 @@ static int llama_decode_internal(
|
||||||
std::vector<std::vector<llama_seq_id>> seq_id;
|
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||||
|
|
||||||
// count outputs
|
// count outputs
|
||||||
if (batch_all.logits) {
|
if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
|
||||||
|
n_outputs = n_tokens_all;
|
||||||
|
} else if (batch_all.logits) {
|
||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||||
n_outputs += batch_all.logits[i] != 0;
|
n_outputs += batch_all.logits[i] != 0;
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
|
} else if (lctx.logits_all) {
|
||||||
n_outputs = n_tokens_all;
|
n_outputs = n_tokens_all;
|
||||||
} else {
|
} else {
|
||||||
// keep last output only
|
// keep last output only
|
||||||
|
@ -12526,30 +12601,13 @@ static int llama_decode_internal(
|
||||||
// no output
|
// no output
|
||||||
res = nullptr;
|
res = nullptr;
|
||||||
embd = nullptr;
|
embd = nullptr;
|
||||||
} else if (!hparams.causal_attn) {
|
|
||||||
res = nullptr; // do not extract logits for embedding models such as BERT
|
|
||||||
|
|
||||||
// token or sequence embeddings
|
|
||||||
embd = gf->nodes[gf->n_nodes - 1];
|
|
||||||
|
|
||||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
|
|
||||||
} else if (cparams.embeddings) {
|
} else if (cparams.embeddings) {
|
||||||
// the embeddings could be in the second to last tensor, or any of the previous tensors
|
res = nullptr; // do not extract logits for embedding case
|
||||||
int i_embd = gf->n_nodes - 2;
|
embd = gf->nodes[gf->n_nodes - 1];
|
||||||
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
|
if (strcmp(embd->name, "result_embd_pooled") != 0) {
|
||||||
i_embd = gf->n_nodes - i;
|
embd = gf->nodes[gf->n_nodes - 2];
|
||||||
if (i_embd < 0) { break; }
|
|
||||||
embd = gf->nodes[i_embd];
|
|
||||||
}
|
|
||||||
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
|
|
||||||
|
|
||||||
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
|
|
||||||
if (!cparams.causal_attn) {
|
|
||||||
res = nullptr; // do not extract logits when not needed
|
|
||||||
// skip computing logits
|
|
||||||
// TODO: is this safe?
|
|
||||||
gf->n_nodes = i_embd + 1;
|
|
||||||
}
|
}
|
||||||
|
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
|
||||||
} else {
|
} else {
|
||||||
embd = nullptr; // do not extract embeddings when not needed
|
embd = nullptr; // do not extract embeddings when not needed
|
||||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
|
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
|
||||||
|
@ -12618,11 +12676,10 @@ static int llama_decode_internal(
|
||||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
|
case LLAMA_POOLING_TYPE_LAST:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
|
|
||||||
|
|
||||||
// extract sequence embeddings
|
// extract sequence embeddings
|
||||||
auto & embd_seq_out = lctx.embd_seq;
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
embd_seq_out.clear();
|
embd_seq_out.clear();
|
||||||
|
@ -13457,113 +13514,143 @@ private:
|
||||||
///// end legacy functions for Falcon //////
|
///// end legacy functions for Falcon //////
|
||||||
|
|
||||||
struct llm_tokenizer_bpe {
|
struct llm_tokenizer_bpe {
|
||||||
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
|
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
|
||||||
|
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
|
||||||
int final_prev_index = -1;
|
|
||||||
bool ignore_merges = false;
|
|
||||||
|
|
||||||
std::vector<std::string> word_collection;
|
|
||||||
switch (vocab.type) {
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
|
||||||
switch (vocab.type_pre) {
|
switch (vocab.type_pre) {
|
||||||
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
||||||
ignore_merges = true;
|
regex_exprs = {
|
||||||
word_collection = unicode_regex_split(text, {
|
|
||||||
// original regex from tokenizer.json
|
// original regex from tokenizer.json
|
||||||
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
|
||||||
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
|
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DBRX:
|
case LLAMA_VOCAB_PRE_TYPE_DBRX:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
|
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
// same as llama3
|
// same as llama3
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\r\n]",
|
"[\r\n]",
|
||||||
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
||||||
"\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
"\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
||||||
"\\s+$",
|
"\\s+$",
|
||||||
"[一-龥ࠀ-一가-]+",
|
"[一-龥ࠀ-一가-]+",
|
||||||
"\\p{N}+",
|
"\\p{N}+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\r\n]",
|
"[\r\n]",
|
||||||
"\\s?\\p{L}+",
|
"\\s?\\p{L}+",
|
||||||
"\\s?\\p{P}+",
|
"\\s?\\p{P}+",
|
||||||
"[一-龥ࠀ-一가-]+",
|
"[一-龥ࠀ-一가-]+",
|
||||||
"\\p{N}",
|
"\\p{N}",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_FALCON:
|
case LLAMA_VOCAB_PRE_TYPE_FALCON:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
"[\\p{P}\\$\\+<=>\\^~\\|`]+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
||||||
// TODO: MPT pre-tokenization regexes are unknown
|
// TODO: MPT pre-tokenization regexes are unknown
|
||||||
// the following are close, but not exact. run the following:
|
// the following are close, but not exact. run the following:
|
||||||
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
|
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
|
||||||
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
|
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"\\s?\\p{L}+",
|
"\\s?\\p{L}+",
|
||||||
"\\s?\\p{P}+",
|
"\\s?\\p{P}+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"\\p{N}",
|
"\\p{N}",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
// original regex from tokenizer.json
|
// original regex from tokenizer.json
|
||||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_PORO:
|
case LLAMA_VOCAB_PRE_TYPE_PORO:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
// default regex for BPE tokenization pre-processing
|
// default regex for BPE tokenization pre-processing
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
"\\p{N}+",
|
"\\p{N}+",
|
||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
|
||||||
|
output.push_back(token_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool append_bos(std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_bos) {
|
||||||
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
|
output.push_back(vocab.special_bos_id);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool append_eos(std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_eos) {
|
||||||
|
GGML_ASSERT(vocab.special_eos_id != -1);
|
||||||
|
output.push_back(vocab.special_eos_id);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
||||||
|
LLAMA_LOG_WARN(
|
||||||
|
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
||||||
|
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
||||||
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
|
}
|
||||||
|
if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
|
||||||
|
LLAMA_LOG_WARN(
|
||||||
|
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
|
||||||
|
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
|
||||||
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
|
int final_prev_index = -1;
|
||||||
|
|
||||||
|
const auto word_collection = unicode_regex_split(text, regex_exprs);
|
||||||
|
|
||||||
symbols_final.clear();
|
symbols_final.clear();
|
||||||
|
|
||||||
for (auto & word : word_collection) {
|
for (auto & word : word_collection) {
|
||||||
|
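After this refactor the regex list is selected once in the constructor and tokenize() simply calls unicode_regex_split(text, regex_exprs). A toy illustration of regex-based pre-tokenization using std::regex, which lacks the \p{...} Unicode classes used above, so the pattern is a simplified ASCII stand-in:

#include <regex>
#include <string>
#include <vector>

std::vector<std::string> regex_split(const std::string & text,
                                     const std::string & expr) {
    std::vector<std::string> out;
    std::regex re(expr);
    // collect every non-overlapping match, left to right
    for (std::sregex_iterator it(text.begin(), text.end(), re), end; it != end; ++it) {
        out.push_back(it->str());
    }
    return out;
}
// regex_split("let's go 123", "'s|'t| ?[a-zA-Z]+| ?[0-9]+")
//   -> {"let", "'s", " go", " 123"}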
@@ -13573,7 +13660,7 @@ struct llm_tokenizer_bpe {
             int index = 0;
             size_t offset = 0;

-            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@ -13654,10 +13741,9 @@ struct llm_tokenizer_bpe {
|
||||||
for (auto j = str.begin(); j != str.end(); ++j) {
|
for (auto j = str.begin(); j != str.end(); ++j) {
|
||||||
std::string byte_str(1, *j);
|
std::string byte_str(1, *j);
|
||||||
auto token_multibyte = vocab.token_to_id.find(byte_str);
|
auto token_multibyte = vocab.token_to_id.find(byte_str);
|
||||||
if (token_multibyte == vocab.token_to_id.end()) {
|
if (token_multibyte != vocab.token_to_id.end()) {
|
||||||
throw std::runtime_error("ERROR: byte not found in vocab");
|
output.push_back(token_multibyte->second);
|
||||||
}
|
}
|
||||||
output.push_back((*token_multibyte).second);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
output.push_back((*token).second);
|
output.push_back((*token).second);
|
||||||
|
@ -13696,6 +13782,8 @@ private:
|
||||||
|
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
|
||||||
|
std::vector<std::string> regex_exprs;
|
||||||
|
|
||||||
std::vector<llm_symbol> symbols;
|
std::vector<llm_symbol> symbols;
|
||||||
std::vector<llm_symbol> symbols_final;
|
std::vector<llm_symbol> symbols_final;
|
||||||
|
|
||||||
|
@ -13705,7 +13793,7 @@ private:
|
||||||
struct llm_tokenizer_wpm {
|
struct llm_tokenizer_wpm {
|
||||||
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
|
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
|
||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
|
||||||
const auto & token_map = vocab.token_to_id;
|
const auto & token_map = vocab.token_to_id;
|
||||||
|
|
||||||
// normalize and split by whitespace
|
// normalize and split by whitespace
|
||||||
|
@ -13714,7 +13802,7 @@ struct llm_tokenizer_wpm {
|
||||||
// bos token prepended already
|
// bos token prepended already
|
||||||
|
|
||||||
// find the longest tokens that form the words
|
// find the longest tokens that form the words
|
||||||
for (const std::string &word : words) {
|
for (const std::string & word : words) {
|
||||||
// skip empty words
|
// skip empty words
|
||||||
if (word.size() == 0) {
|
if (word.size() == 0) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -13731,7 +13819,7 @@ struct llm_tokenizer_wpm {
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
// loop through possible match length
|
// loop through possible match length
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (int j = n; j > i; j--) {
|
for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
|
||||||
auto it = token_map.find(word1.substr(i, j - i));
|
auto it = token_map.find(word1.substr(i, j - i));
|
||||||
if (it != token_map.end()) {
|
if (it != token_map.end()) {
|
||||||
output.push_back(it->second);
|
output.push_back(it->second);
|
||||||
|
@ -13754,7 +13842,8 @@ struct llm_tokenizer_wpm {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> preprocess(const std::string & text) {
|
// TODO: reduce string copies by using cpts_offs array
|
||||||
|
std::vector<std::string> preprocess(const std::string & text) const {
|
||||||
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
||||||
std::vector<std::string> words(1, "");
|
std::vector<std::string> words(1, "");
|
||||||
|
|
||||||
|
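Editor's note: the WPM hunk above caps the greedy longest-match search at the length of the longest vocabulary token instead of the whole word. A minimal, self-contained C++ sketch of that search pattern follows; the toy vocabulary and the max_token_len value are illustrative, not from this commit.

#include <algorithm>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    // toy WPM-style vocabulary; ids are arbitrary
    const std::unordered_map<std::string, int> token_map = {
        {"un", 1}, {"believ", 2}, {"able", 3},
    };
    const int max_token_len = 6; // length of the longest entry ("believ")

    const std::string word = "unbelievable";
    std::vector<int> output;
    const int n = (int) word.size();
    for (int i = 0; i < n; ) {
        bool match = false;
        // probe the longest candidate first, but never longer than the
        // longest known token -- the same bound as the hunk above
        for (int j = std::min(n, i + max_token_len + 1); j > i; j--) {
            auto it = token_map.find(word.substr(i, j - i));
            if (it != token_map.end()) {
                output.push_back(it->second);
                i = j;
                match = true;
                break;
            }
        }
        if (!match) {
            i++; // skip an unmatched byte (the real WPM code handles this case differently)
        }
    }
    for (int id : output) std::printf("%d ", id); // prints: 1 2 3
    std::printf("\n");
    return 0;
}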
@@ -13976,7 +14065,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &

                 bool is_prev_special = false;

-                if (add_special && vocab.special_add_bos != 0) {
+                if (add_special && vocab.tokenizer_add_bos) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                     is_prev_special = true;
@@ -13986,7 +14075,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);

-                        if (vocab.add_space_prefix) {
+                        if (vocab.tokenizer_add_space_prefix) {
                             if (!output.size() || is_prev_special) { // prefix with space if first token
                                 raw_text = " " + raw_text;
                             }
@@ -14004,24 +14093,52 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }

-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                     // LLAMA_LOG_WARN(
                     //     "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                     //     "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                     //     "Are you sure this is what you want?\n", __FUNCTION__);
                 }

-                if (add_special && vocab.special_add_eos == 1) {
+                if (add_special && vocab.tokenizer_add_eos) {
                     GGML_ASSERT(vocab.special_eos_id != -1);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                if (add_special && vocab.special_add_bos != 0) {
+                if (OldBPETokenizerMode)
+                {
+                    if (add_special && vocab.tokenizer_add_bos != 0)
+                    {
                         GGML_ASSERT(vocab.special_bos_id != -1);
                         output.push_back(vocab.special_bos_id);
                     }
+                    for (const auto &fragment : fragment_buffer)
+                    {
+                        if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                        {
+                            auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+                            llm_tokenizer_bpe_old tokenizer(vocab);
+                            tokenizer.tokenize(raw_text, output);
+                        }
+                        else
+                        {
+                            output.push_back(fragment.token);
+                        }
+                    }
+                    if (add_special && vocab.tokenizer_add_eos == 1)
+                    {
+                        output.push_back(vocab.special_eos_id);
+                    }
+                    break;
+                }
+
+                llm_tokenizer_bpe tokenizer(vocab);
+
+                if (add_special) {
+                    tokenizer.append_bos(output);
+                }
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -14031,32 +14148,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif

-                        if(OldBPETokenizerMode)
-                        {
-                            llm_tokenizer_bpe_old tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
-                        }
-                        else
-                        {
-                            llm_tokenizer_bpe tokenizer(vocab);
-                            tokenizer.tokenize(raw_text, output);
-                        }

                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+                        tokenizer.append(fragment.token, output);
                     }
                 }

-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    // LLAMA_LOG_WARN(
-                    //     "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                    //     "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                    //     "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_add_eos != -1);
-                    output.push_back(vocab.special_eos_id);
+                if (add_special) {
+                    tokenizer.append_eos(output);
+                    tokenizer.check_double_bos_eos(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
@@ -14066,6 +14166,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_cls_id);
                 }

+                llm_tokenizer_wpm tokenizer(vocab);
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -14073,7 +14175,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_wpm tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
@@ -18399,6 +18500,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }

+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
@@ -18642,11 +18747,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
 }

 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return model->vocab.special_add_bos;
+    return model->vocab.tokenizer_add_bos;
 }

 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return model->vocab.special_add_eos;
+    return model->vocab.tokenizer_add_eos;
 }

 llama_token llama_token_prefix(const struct llama_model * model) {

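Editor's note: the refactor above moves BOS/EOS handling out of llama_tokenize_internal and into tokenizer helpers (append_bos, append_eos, check_double_bos_eos). A minimal standalone sketch of the double-BOS check follows; the toy_vocab struct is a stand-in for the real llama_vocab, used here only so the snippet compiles on its own.

#include <cassert>
#include <cstdio>
#include <vector>

struct toy_vocab {
    bool tokenizer_add_bos = true;
    int  special_bos_id    = 1;
};

static bool append_bos(const toy_vocab & vocab, std::vector<int> & output) {
    if (vocab.tokenizer_add_bos) {
        assert(vocab.special_bos_id != -1);
        output.push_back(vocab.special_bos_id);
        return true;
    }
    return false;
}

static void check_double_bos(const toy_vocab & vocab, const std::vector<int> & output) {
    // same condition as check_double_bos_eos() above: warn when a BOS was
    // prepended even though the prompt itself already began with BOS
    if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
        std::printf("warning: the final prompt now starts with 2 BOS tokens\n");
    }
}

int main() {
    toy_vocab vocab;
    std::vector<int> output;
    append_bos(vocab, output);
    output.push_back(vocab.special_bos_id); // simulate a prompt that itself begins with BOS
    check_double_bos(vocab, output);        // triggers the warning
    return 0;
}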
llama.h (6 changes)
@@ -174,6 +174,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
+        LLAMA_POOLING_TYPE_LAST = 3,
     };

     enum llama_split_mode {
@@ -293,7 +294,6 @@ extern "C" {

         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-                                                        // (ignored if no pooling layer)

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -788,6 +788,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);

+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

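Editor's note: llama_set_embeddings(), added above, toggles a context between returning embeddings and logits at runtime. A hypothetical caller-side sketch follows (not part of this commit; it assumes this tree's llama.h and a valid GGUF model path on the command line, with error handling elided).

#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) return 1;

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (!model) return 1;

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) return 1;

    llama_set_embeddings(ctx, true);   // decode calls now produce embeddings, not logits
    // ... run llama_decode() on a batch, then read llama_get_embeddings(ctx) ...
    llama_set_embeddings(ctx, false);  // back to the normal logits/generation mode

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}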
scripts/gen-unicode-data.py

@@ -1,83 +1,143 @@
-import regex
-import ctypes
+import array
 import unicodedata
+import requests

-
-class CoodepointFlags (ctypes.Structure):
-    _fields_ = [  # see definition in unicode.h
-        ("is_undefined",   ctypes.c_uint16, 1),
-        ("is_number",      ctypes.c_uint16, 1),  # regex: \p{N}
-        ("is_letter",      ctypes.c_uint16, 1),  # regex: \p{L}
-        ("is_separator",   ctypes.c_uint16, 1),  # regex: \p{Z}
-        ("is_accent_mark", ctypes.c_uint16, 1),  # regex: \p{M}
-        ("is_punctuation", ctypes.c_uint16, 1),  # regex: \p{P}
-        ("is_symbol",      ctypes.c_uint16, 1),  # regex: \p{S}
-        ("is_control",     ctypes.c_uint16, 1),  # regex: \p{C}
-    ]
-
-
-assert (ctypes.sizeof(CoodepointFlags) == 2)
-

 MAX_CODEPOINTS = 0x110000

-regex_number      = regex.compile(r'\p{N}')
-regex_letter      = regex.compile(r'\p{L}')
-regex_separator   = regex.compile(r'\p{Z}')
-regex_accent_mark = regex.compile(r'\p{M}')
-regex_punctuation = regex.compile(r'\p{P}')
-regex_symbol      = regex.compile(r'\p{S}')
-regex_control     = regex.compile(r'\p{C}')
-regex_whitespace  = regex.compile(r'\s')
-
-codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
+UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
+
+
+# see https://www.unicode.org/L2/L1999/UnicodeData.html
+def unicode_data_iter():
+    res = requests.get(UNICODE_DATA_URL)
+    res.raise_for_status()
+    data = res.content.decode()
+
+    prev = []
+
+    for line in data.splitlines():
+        # e.g.: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
+        line = line.split(";")
+
+        cpt = int(line[0], base=16)
+        assert cpt < MAX_CODEPOINTS
+
+        cpt_lower = int(line[-2] or "0", base=16)
+        assert cpt_lower < MAX_CODEPOINTS
+
+        cpt_upper = int(line[-3] or "0", base=16)
+        assert cpt_upper < MAX_CODEPOINTS
+
+        categ = line[2].strip()
+        assert len(categ) == 2
+
+        bidir = line[4].strip()
+        assert len(categ) == 2
+
+        name = line[1]
+        if name.endswith(", First>"):
+            prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
+            continue
+        if name.endswith(", Last>"):
+            assert prev[1:] == (0, 0, categ, bidir)
+            for c in range(prev[0], cpt):
+                yield (c, cpt_lower, cpt_upper, categ, bidir)
+
+        yield (cpt, cpt_lower, cpt_upper, categ, bidir)
+
+
+# see definition in unicode.h
+CODEPOINT_FLAG_UNDEFINED   = 0x0001  #
+CODEPOINT_FLAG_NUMBER      = 0x0002  # \p{N}
+CODEPOINT_FLAG_LETTER      = 0x0004  # \p{L}
+CODEPOINT_FLAG_SEPARATOR   = 0x0008  # \p{Z}
+CODEPOINT_FLAG_MARK        = 0x0010  # \p{M}
+CODEPOINT_FLAG_PUNCTUATION = 0x0020  # \p{P}
+CODEPOINT_FLAG_SYMBOL      = 0x0040  # \p{S}
+CODEPOINT_FLAG_CONTROL     = 0x0080  # \p{C}
+
+UNICODE_CATEGORY_TO_FLAG = {
+    "Cn": CODEPOINT_FLAG_UNDEFINED,    # Undefined
+    "Cc": CODEPOINT_FLAG_CONTROL,      # Control
+    "Cf": CODEPOINT_FLAG_CONTROL,      # Format
+    "Co": CODEPOINT_FLAG_CONTROL,      # Private Use
+    "Cs": CODEPOINT_FLAG_CONTROL,      # Surrogate
+    "Ll": CODEPOINT_FLAG_LETTER,       # Lowercase Letter
+    "Lm": CODEPOINT_FLAG_LETTER,       # Modifier Letter
+    "Lo": CODEPOINT_FLAG_LETTER,       # Other Letter
+    "Lt": CODEPOINT_FLAG_LETTER,       # Titlecase Letter
+    "Lu": CODEPOINT_FLAG_LETTER,       # Uppercase Letter
+    "L&": CODEPOINT_FLAG_LETTER,       # Cased Letter
+    "Mc": CODEPOINT_FLAG_MARK,         # Spacing Mark
+    "Me": CODEPOINT_FLAG_MARK,         # Enclosing Mark
+    "Mn": CODEPOINT_FLAG_MARK,         # Nonspacing Mark
+    "Nd": CODEPOINT_FLAG_NUMBER,       # Decimal Number
+    "Nl": CODEPOINT_FLAG_NUMBER,       # Letter Number
+    "No": CODEPOINT_FLAG_NUMBER,       # Other Number
+    "Pc": CODEPOINT_FLAG_PUNCTUATION,  # Connector Punctuation
+    "Pd": CODEPOINT_FLAG_PUNCTUATION,  # Dash Punctuation
+    "Pe": CODEPOINT_FLAG_PUNCTUATION,  # Close Punctuation
+    "Pf": CODEPOINT_FLAG_PUNCTUATION,  # Final Punctuation
+    "Pi": CODEPOINT_FLAG_PUNCTUATION,  # Initial Punctuation
+    "Po": CODEPOINT_FLAG_PUNCTUATION,  # Other Punctuation
+    "Ps": CODEPOINT_FLAG_PUNCTUATION,  # Open Punctuation
+    "Sc": CODEPOINT_FLAG_SYMBOL,       # Currency Symbol
+    "Sk": CODEPOINT_FLAG_SYMBOL,       # Modifier Symbol
+    "Sm": CODEPOINT_FLAG_SYMBOL,       # Math Symbol
+    "So": CODEPOINT_FLAG_SYMBOL,       # Other Symbol
+    "Zl": CODEPOINT_FLAG_SEPARATOR,    # Line Separator
+    "Zp": CODEPOINT_FLAG_SEPARATOR,    # Paragraph Separator
+    "Zs": CODEPOINT_FLAG_SEPARATOR,    # Space Separator
+}
+
+
+codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
 table_whitespace = []
 table_lowercase = []
 table_uppercase = []
 table_nfd = []

-for codepoint in range(MAX_CODEPOINTS):
+for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
     # convert codepoint to unicode character
-    char = chr(codepoint)
+    char = chr(cpt)

-    # regex categories
-    flags = codepoint_flags[codepoint]
-    flags.is_number      = bool(regex_number.match(char))
-    flags.is_letter      = bool(regex_letter.match(char))
-    flags.is_separator   = bool(regex_separator.match(char))
-    flags.is_accent_mark = bool(regex_accent_mark.match(char))
-    flags.is_punctuation = bool(regex_punctuation.match(char))
-    flags.is_symbol      = bool(regex_symbol.match(char))
-    flags.is_control     = bool(regex_control.match(char))
-    flags.is_undefined   = bytes(flags)[0] == 0
-    assert (not flags.is_undefined)
-
-    # whitespaces
-    if bool(regex_whitespace.match(char)):
-        table_whitespace.append(codepoint)
+    # codepoint category flags
+    codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]

     # lowercase conversion
-    lower = ord(char.lower()[0])
-    if codepoint != lower:
-        table_lowercase.append((codepoint, lower))
+    if cpt_lower:
+        table_lowercase.append((cpt, cpt_lower))

     # uppercase conversion
-    upper = ord(char.upper()[0])
-    if codepoint != upper:
-        table_uppercase.append((codepoint, upper))
+    if cpt_upper:
+        table_uppercase.append((cpt, cpt_upper))

     # NFD normalization
     norm = ord(unicodedata.normalize('NFD', char)[0])
-    if codepoint != norm:
-        table_nfd.append((codepoint, norm))
+    if cpt != norm:
+        table_nfd.append((cpt, norm))
+
+
+# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
+table_whitespace.extend(range(0x0009, 0x000D + 1))
+table_whitespace.extend(range(0x2000, 0x200A + 1))
+table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
+
+
+# sort by codepoint
+table_whitespace.sort()
+table_lowercase.sort()
+table_uppercase.sort()
+table_nfd.sort()


 # group ranges with same flags
 ranges_flags = [(0, codepoint_flags[0])]  # start, flags
 for codepoint, flags in enumerate(codepoint_flags):
-    if bytes(flags) != bytes(ranges_flags[-1][1]):
+    if flags != ranges_flags[-1][1]:
         ranges_flags.append((codepoint, flags))
-ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
+ranges_flags.append((MAX_CODEPOINTS, 0x0000))


 # group ranges with same nfd
@@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
         ranges_nfd[-1] = (start, codepoint, norm)


-# Generate 'unicode-data.cpp'
+# Generate 'unicode-data.cpp':
+#   python ./scripts//gen-unicode-data.py > unicode-data.cpp

 def out(line=""):
     print(line, end='\n')  # noqa
@@ -110,12 +170,12 @@ out("""\

 out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
 for codepoint, flags in ranges_flags:
-    flags = int.from_bytes(bytes(flags), "little")
     out("{0x%06X, 0x%04X}," % (codepoint, flags))
 out("};\n")

 out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
-out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
+for codepoint in table_whitespace:
+    out("0x%06X," % codepoint)
 out("};\n")

 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")

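Editor's note: the generated unicode_ranges_flags table stores one (start, flags) entry per run of codepoints, each entry covering [start, next_start - 1]. The following C++ sketch illustrates how such a range table can be queried with a binary search; the toy data and the helper name flags_for_cpt are assumptions for illustration, and the actual lookup lives in unicode.cpp and may differ.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// find the range containing cpt: first entry whose start exceeds cpt, minus one
static uint16_t flags_for_cpt(const std::vector<std::pair<uint32_t, uint16_t>> & ranges, uint32_t cpt) {
    auto it = std::upper_bound(ranges.begin(), ranges.end(), cpt,
        [](uint32_t c, const std::pair<uint32_t, uint16_t> & r) { return c < r.first; });
    return (it == ranges.begin()) ? 0 : std::prev(it)->second;
}

int main() {
    // toy table: [0x00,0x2F] control, [0x30,0x39] number, [0x3A,...] undefined
    const std::vector<std::pair<uint32_t, uint16_t>> ranges = {
        {0x000000, 0x0080}, {0x000030, 0x0002}, {0x00003A, 0x0001},
    };
    std::printf("flags for 0x31 -> 0x%04X\n", (unsigned) flags_for_cpt(ranges, 0x31)); // 0x0002 (number)
    return 0;
}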
sgemm.cpp

@@ -43,8 +43,10 @@
 // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
 //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].

+#if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif

 #include "sgemm.h"
 #include "ggml-impl.h"

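Editor's note: the sgemm.cpp hunk wraps GCC-style pragmas so compilers that do not define __GNUC__ (e.g. MSVC) never see them and do not warn about unknown pragmas. A minimal sketch of the same guard pattern, with a placeholder function body:

// compilers without __GNUC__ skip the pragma block entirely
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#endif

int add(int a, int b) { return a + b; }  // placeholder translation-unit body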
unicode-data.cpp (1652 changes; file diff suppressed because it is too large)

unicode.cpp (30 changes)
@@ -226,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         assert(offset_end <= cpts.size());
         start = offset_end;

+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
         auto _get_cpt = [&] (const size_t pos) -> uint32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };

         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@@ -309,7 +310,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
             }

             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -344,8 +345,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         assert(offset_end <= cpts.size());
         start = offset_end;

+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
         auto _get_cpt = [&] (const size_t pos) -> uint32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };

         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@@ -450,7 +452,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }

             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -594,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c

 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
@@ -679,10 +682,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 continue;
             }

-            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
+            const auto flags = unicode_cpt_flags(cpts[i]);

-            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
+            if (flags.is_whitespace) {
+                //NOTE: C++ std::regex \s does not match 0x85, Rust and Python regex does.
+                //text_collapsed[i] = (char) 0x85;  // <Next Line> as whitespace fallback
+                text_collapsed[i] = (char) 0x0B;    // <vertical tab> as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
             } else {
                 text_collapsed[i] = (char) 0xD0;    // fallback
             }
@@ -766,9 +773,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
             bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
         } else {
             // no unicode category used, we can use std::wregex directly
-            const std::wstring wtext = unicode_wstring_from_utf8(text);
             const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);

+            // std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
+            std::wstring wtext(cpts.begin(), cpts.end());
+            for (size_t i = 0; i < wtext.size(); ++i) {
+                if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                    wtext[i] = 0x0B;
+                }
+            }
+
             //printf("text: %s\n", text.c_str());
             //printf("regex_expr: %s\n", regex_expr.c_str());
             bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);

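Editor's note: the last hunk works around std::wregex's \s not matching non-ASCII whitespace by rewriting such codepoints to 0x0B (vertical tab) before matching. The following standalone demonstration of the idea is an assumption-level sketch, not code from this commit; exact \s behavior is implementation- and locale-dependent.

#include <cstdio>
#include <regex>
#include <string>

int main() {
    // ideographic space U+3000 between two ASCII letters
    // (literal split so 'b' is not consumed by the \x escape)
    std::wstring text = L"a\x3000" L"b";
    const std::wregex ws(L"\\s");

    std::printf("\\s matches U+3000 directly: %d\n", (int) std::regex_search(text, ws));

    // rewrite non-ASCII whitespace to 0x0B, as the hunk above does
    for (auto & wc : text) {
        if (wc > 0x7F && wc == 0x3000) {  // stand-in for the real is_whitespace flag check
            wc = 0x0B;
        }
    }
    std::printf("\\s matches after 0x0B fallback: %d\n", (int) std::regex_search(text, ws));
    return 0;
}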